From b50a7d32921295e8274dc1acca59e506355a6cd7 Mon Sep 17 00:00:00 2001 From: Ian Shim Date: Fri, 5 Jan 2024 01:17:23 -0800 Subject: [PATCH] allowlist with specified rate limits --- common/ratelimit/limiter.go | 18 +---- common/ratelimit/limiter_cli.go | 10 --- common/ratelimit/ratelimit_test.go | 18 +---- disperser/apiserver/rate_config.go | 51 ++++++++++++ disperser/apiserver/server.go | 80 +++++++++++++++---- disperser/apiserver/server_test.go | 124 ++++++++++++++++++++++++++++- disperser/cmd/apiserver/main.go | 2 +- node/cmd/main.go | 2 +- 8 files changed, 241 insertions(+), 64 deletions(-) diff --git a/common/ratelimit/limiter.go b/common/ratelimit/limiter.go index 3a526c28eb..30bd578438 100644 --- a/common/ratelimit/limiter.go +++ b/common/ratelimit/limiter.go @@ -2,7 +2,6 @@ package ratelimit import ( "context" - "strings" "time" "github.com/Layr-Labs/eigenda/common" @@ -12,35 +11,23 @@ type BucketStore = common.KVStore[common.RateBucketParams] type rateLimiter struct { globalRateParams common.GlobalRateParams - - bucketStore BucketStore - allowlist []string + bucketStore BucketStore logger common.Logger } -func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, allowlist []string, logger common.Logger) common.RateLimiter { +func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, logger common.Logger) common.RateLimiter { return &rateLimiter{ globalRateParams: rateParams, bucketStore: bucketStore, - allowlist: allowlist, logger: logger, } } // Checks whether a request from the given requesterID is allowed func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.RequesterID, blobSize uint, rate common.RateParam) (bool, error) { - // TODO: temporary allowlist that unconditionally allows request - // for testing purposes only - for _, id := range d.allowlist { - if strings.Contains(requesterID, id) { - return true, nil - } - } - // Retrieve bucket params for the requester ID // This will be from 
dynamo for Disperser and from local storage for DA node - bucketParams, err := d.bucketStore.GetItem(ctx, requesterID) if err != nil { @@ -68,7 +55,6 @@ func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.Reque // Update the bucket level bucketParams.BucketLevels[i] = getBucketLevel(bucketParams.BucketLevels[i], size, interval, deduction) - allowed = allowed && bucketParams.BucketLevels[i] > 0 } diff --git a/common/ratelimit/limiter_cli.go b/common/ratelimit/limiter_cli.go index 91aa310940..86b8794bc6 100644 --- a/common/ratelimit/limiter_cli.go +++ b/common/ratelimit/limiter_cli.go @@ -15,14 +15,12 @@ const ( BucketMultipliersFlagName = "bucket-multipliers" CountFailedFlagName = "count-failed" BucketStoreSizeFlagName = "bucket-store-size" - AllowlistFlagName = "allowlist" ) type Config struct { common.GlobalRateParams BucketStoreSize int UniformRateParam common.RateParam - Allowlist []string } func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag { @@ -54,13 +52,6 @@ func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag { EnvVar: common.PrefixEnvVar(envPrefix, "BUCKET_STORE_SIZE"), Required: false, }, - cli.StringSliceFlag{ - Name: common.PrefixFlag(flagPrefix, AllowlistFlagName), - Usage: "Allowlist of IPs to bypass rate limiting", - EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"), - Required: false, - Value: &cli.StringSlice{}, - }, } } @@ -106,7 +97,6 @@ func ReadCLIConfig(ctx *cli.Context, flagPrefix string) (Config, error) { cfg.Multipliers = multipliers cfg.GlobalRateParams.CountFailed = ctx.Bool(common.PrefixFlag(flagPrefix, CountFailedFlagName)) cfg.BucketStoreSize = ctx.Int(common.PrefixFlag(flagPrefix, BucketStoreSizeFlagName)) - cfg.Allowlist = ctx.StringSlice(common.PrefixFlag(flagPrefix, AllowlistFlagName)) err := validateConfig(cfg) if err != nil { diff --git a/common/ratelimit/ratelimit_test.go b/common/ratelimit/ratelimit_test.go index bf58a4c0c2..e74c0da8c3 100644 --- 
a/common/ratelimit/ratelimit_test.go +++ b/common/ratelimit/ratelimit_test.go @@ -25,7 +25,7 @@ func makeTestRatelimiter() (common.RateLimiter, error) { return nil, err } - ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, []string{"testRetriever2"}, &mock.Logger{}) + ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, &mock.Logger{}) return ratelimiter, nil @@ -50,19 +50,3 @@ func TestRatelimit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, false, allow) } - -func TestRatelimitAllowlist(t *testing.T) { - ratelimiter, err := makeTestRatelimiter() - assert.NoError(t, err) - - ctx := context.Background() - - retreiverID := "testRetriever2" - - // 10x more requests allowed for allowlisted IDs - for i := 0; i < 100; i++ { - allow, err := ratelimiter.AllowRequest(ctx, retreiverID, 10, 100) - assert.NoError(t, err) - assert.Equal(t, true, allow) - } -} diff --git a/disperser/apiserver/rate_config.go b/disperser/apiserver/rate_config.go index 3ad427d3ee..ebe3868529 100644 --- a/disperser/apiserver/rate_config.go +++ b/disperser/apiserver/rate_config.go @@ -2,7 +2,9 @@ package apiserver import ( "fmt" + "log" "strconv" + "strings" "github.com/Layr-Labs/eigenda/common" "github.com/Layr-Labs/eigenda/core" @@ -16,6 +18,7 @@ const ( TotalUnauthBlobRateFlagName = "auth.total-unauth-blob-rate" PerUserUnauthBlobRateFlagName = "auth.per-user-unauth-blob-rate" ClientIPHeaderFlagName = "auth.client-ip-header" + AllowlistFlagName = "auth.allowlist" // We allow the user to specify the blob rate in blobs/sec, but internally we use blobs/sec * 1e6 (i.e. blobs/microsec). // This is because the rate limiter takes an integer rate. 
@@ -29,9 +32,17 @@ type QuorumRateInfo struct { TotalUnauthBlobRate common.RateParam } +type PerUserRateInfo struct { + Throughput common.RateParam + BlobRate common.RateParam +} + +type Allowlist = map[string]map[core.QuorumID]PerUserRateInfo + type RateConfig struct { QuorumRateInfos map[core.QuorumID]QuorumRateInfo ClientIPHeader string + Allowlist Allowlist } func CLIFlags(envPrefix string) []cli.Flag { @@ -73,6 +84,13 @@ func CLIFlags(envPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "CLIENT_IP_HEADER"), }, + cli.StringSliceFlag{ + Name: AllowlistFlagName, + Usage: "Allowlist of IPs and corresponding blob/byte rates to bypass rate limiting. Format: <IP>:<quorum ID>:<blob rate>:<byte rate>. Example: 127.0.0.1:0:10:10485760", + EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"), + Required: false, + Value: &cli.StringSlice{}, + }, } } @@ -112,8 +130,41 @@ func ReadCLIConfig(c *cli.Context) (RateConfig, error) { } } + // Parse allowlist + allowlist := make(Allowlist) + for _, allowlistEntry := range c.StringSlice(AllowlistFlagName) { + allowlistEntrySplit := strings.Split(allowlistEntry, ":") + if len(allowlistEntrySplit) != 4 { + log.Printf("invalid allowlist entry: entry should contain exactly 4 elements: %s", allowlistEntry) + continue + } + ip := allowlistEntrySplit[0] + quorumID, err := strconv.Atoi(allowlistEntrySplit[1]) + if err != nil { + log.Printf("invalid allowlist entry: failed to convert quorum ID from string: %s", allowlistEntry) + continue + } + blobRate, err := strconv.ParseFloat(allowlistEntrySplit[2], 64) + if err != nil { + log.Printf("invalid allowlist entry: failed to convert blob rate from string: %s", allowlistEntry) + continue + } + byteRate, err := strconv.ParseFloat(allowlistEntrySplit[3], 64) + if err != nil { + log.Printf("invalid allowlist entry: failed to convert throughput from string: %s", allowlistEntry) + continue + } + allowlist[ip] = map[core.QuorumID]PerUserRateInfo{ + core.QuorumID(quorumID): { + Throughput:
common.RateParam(byteRate), + BlobRate: common.RateParam(blobRate * blobRateMultiplier), + }, + } + } + return RateConfig{ QuorumRateInfos: quorumRateInfos, ClientIPHeader: c.String(ClientIPHeaderFlagName), + Allowlist: allowlist, }, nil } diff --git a/disperser/apiserver/server.go b/disperser/apiserver/server.go index 42dcd7c028..f315e11c95 100644 --- a/disperser/apiserver/server.go +++ b/disperser/apiserver/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "strings" "sync" "time" @@ -18,8 +19,10 @@ import ( "google.golang.org/grpc/reflection" ) -var errSystemRateLimit = fmt.Errorf("request ratelimited: system limit") -var errAccountRateLimit = fmt.Errorf("request ratelimited: account limit") +var errSystemBlobRateLimit = fmt.Errorf("request ratelimited: system blob limit") +var errSystemThroughputRateLimit = fmt.Errorf("request ratelimited: system throughput limit") +var errAccountBlobRateLimit = fmt.Errorf("request ratelimited: account blob limit") +var errAccountThroughputRateLimit = fmt.Errorf("request ratelimited: account throughput limit") const systemAccountKey = "system" @@ -55,6 +58,11 @@ func NewDispersalServer( ratelimiter common.RateLimiter, rateConfig RateConfig, ) *DispersalServer { + for ip, rateInfoByQuorum := range rateConfig.Allowlist { + for quorumID, rateInfo := range rateInfoByQuorum { + logger.Info("[Allowlist]", "ip", ip, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate) + } + } return &DispersalServer{ config: config, blobStore: store, @@ -139,9 +147,9 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob if err != nil { for _, param := range securityParams { quorumId := string(uint8(param.GetQuorumId())) - if errors.Is(err, errSystemRateLimit) { + if errors.Is(err, errSystemBlobRateLimit) || errors.Is(err, errSystemThroughputRateLimit) { s.metrics.HandleSystemRateLimitedRequest(quorumId, blobSize, "DisperseBlob") - } else if errors.Is(err, errAccountRateLimit) { + 
} else if errors.Is(err, errAccountBlobRateLimit) || errors.Is(err, errAccountThroughputRateLimit) { s.metrics.HandleAccountRateLimitedRequest(quorumId, blobSize, "DisperseBlob") } else { s.metrics.HandleFailedRequest(quorumId, blobSize, "DisperseBlob") @@ -173,23 +181,65 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob }, nil } +func (s *DispersalServer) getAccountRate(origin string, quorumID core.QuorumID) (*PerUserRateInfo, error) { + unauthRates, ok := s.rateConfig.QuorumRateInfos[quorumID] + if !ok { + return nil, fmt.Errorf("no configured rate exists for quorum %d", quorumID) + } + + for ip, rateInfoByQuorum := range s.rateConfig.Allowlist { + if !strings.Contains(origin, ip) { + continue + } + + rateInfo, ok := rateInfoByQuorum[quorumID] + if !ok { + continue + } + + throughput := unauthRates.PerUserUnauthThroughput + if rateInfo.Throughput > 0 { + throughput = rateInfo.Throughput + } + + blobRate := unauthRates.PerUserUnauthBlobRate + if rateInfo.BlobRate > 0 { + blobRate = rateInfo.BlobRate + } + + return &PerUserRateInfo{ + Throughput: throughput, + BlobRate: blobRate, + }, nil + } + + return &PerUserRateInfo{ + Throughput: unauthRates.PerUserUnauthThroughput, + BlobRate: unauthRates.PerUserUnauthBlobRate, + }, nil +} + func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *core.Blob, origin string) error { // TODO(robert): Remove these locks once we have resolved ratelimiting approach s.mu.Lock() defer s.mu.Unlock() - for _, param := range blob.RequestHeader.SecurityParams { + for i, param := range blob.RequestHeader.SecurityParams { rates, ok := s.rateConfig.QuorumRateInfos[param.QuorumID] if !ok { return fmt.Errorf("no configured rate exists for quorum %d", param.QuorumID) } + accountRates, err := s.getAccountRate(origin, param.QuorumID) + if err != nil { + return err + } // Get the encoded blob size from the blob header. 
Calculation is done in a way that nodes can replicate blobSize := len(blob.Data) length := core.GetBlobLength(uint(blobSize)) - encodedLength := core.GetEncodedBlobLength(length, uint8(blob.RequestHeader.SecurityParams[param.QuorumID].QuorumThreshold), uint8(blob.RequestHeader.SecurityParams[param.QuorumID].AdversaryThreshold)) + encodedLength := core.GetEncodedBlobLength(length, uint8(param.QuorumThreshold), uint8(param.AdversaryThreshold)) encodedSize := core.GetBlobSize(encodedLength) s.logger.Debug("checking rate limits", "origin", origin, "quorum", param.QuorumID, "encodedSize", encodedSize, "blobSize", blobSize) @@ -202,7 +252,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob * } if !allowed { s.logger.Warn("system byte ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", rates.TotalUnauthThroughput) - return errSystemRateLimit + return errSystemThroughputRateLimit } systemQuorumKey = fmt.Sprintf("%s:%d-blobrate", systemAccountKey, param.QuorumID) @@ -212,7 +262,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob * } if !allowed { s.logger.Warn("system blob ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", float32(rates.TotalUnauthBlobRate)/blobRateMultiplier) - return errSystemRateLimit + return errSystemBlobRateLimit } // Check Account Ratelimit @@ -220,27 +270,27 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob * blob.RequestHeader.AccountID = "ip:" + origin userQuorumKey := fmt.Sprintf("%s:%d", blob.RequestHeader.AccountID, param.QuorumID) - allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, rates.PerUserUnauthThroughput) + allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, accountRates.Throughput) if err != nil { return fmt.Errorf("ratelimiter error: %v", err) } if !allowed { - s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", 
rates.PerUserUnauthThroughput) - return errAccountRateLimit + s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", accountRates.Throughput) + return errAccountThroughputRateLimit } userQuorumKey = fmt.Sprintf("%s:%d-blobrate", blob.RequestHeader.AccountID, param.QuorumID) - allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, rates.PerUserUnauthBlobRate) + allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, accountRates.BlobRate) if err != nil { return fmt.Errorf("ratelimiter error: %v", err) } if !allowed { - s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(rates.PerUserUnauthBlobRate)/blobRateMultiplier) - return errAccountRateLimit + s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(accountRates.BlobRate)/blobRateMultiplier) + return errAccountBlobRateLimit } // Update the quorum rate - blob.RequestHeader.SecurityParams[param.QuorumID].QuorumRate = rates.PerUserUnauthThroughput + blob.RequestHeader.SecurityParams[i].QuorumRate = accountRates.Throughput } return nil diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 318e055cda..3df211b0bb 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -20,6 +20,8 @@ import ( "github.com/Layr-Labs/eigenda/common/aws/dynamodb" "github.com/Layr-Labs/eigenda/common/aws/s3" "github.com/Layr-Labs/eigenda/common/logging" + "github.com/Layr-Labs/eigenda/common/ratelimit" + "github.com/Layr-Labs/eigenda/common/store" "github.com/Layr-Labs/eigenda/core" "github.com/Layr-Labs/eigenda/core/mock" "github.com/Layr-Labs/eigenda/disperser" @@ -52,7 +54,7 @@ func TestMain(m *testing.M) { } func TestDisperseBlob(t *testing.T) { - data := make([]byte, 1024*1024) + data := make([]byte, 3*1024) _, err := rand.Read(data) assert.NoError(t, err) @@ -277,6 +279,90 @@ func 
TestDisperseBlobWithExceedSizeLimit(t *testing.T) { assert.Equal(t, err.Error(), "blob size cannot exceed 2 MiB") } +func TestRatelimit(t *testing.T) { + data50KiB := make([]byte, 50*1024) + _, err := rand.Read(data50KiB) + assert.NoError(t, err) + data1KiB := make([]byte, 1024) + _, err = rand.Read(data1KiB) + assert.NoError(t, err) + + // Try with a non-allowlisted IP + p := &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("0.0.0.0"), + Port: 51001, + }, + } + ctx := peer.NewContext(context.Background(), p) + + // Try with non-allowlisted IP + // Should fail with account throughput limit because unauth throughput limit is 20 KiB/s for quorum 0 + _, err = dispersalServer.DisperseBlob(ctx, &pb.DisperseBlobRequest{ + Data: data50KiB, + SecurityParams: []*pb.SecurityParams{ + { + QuorumId: 0, + AdversaryThreshold: 50, + QuorumThreshold: 100, + }, + }, + }) + assert.ErrorContains(t, err, "account throughput limit") + + // Try with non-allowlisted IP. Should fail with account blob limit because blob rate (3 blobs/s) X bucket size (3s) is smaller than 10 blobs. + for i := 0; i < 10; i++ { + _, err = dispersalServer.DisperseBlob(ctx, &pb.DisperseBlobRequest{ + Data: data1KiB, + SecurityParams: []*pb.SecurityParams{ + { + QuorumId: 1, + AdversaryThreshold: 50, + QuorumThreshold: 100, + }, + }, + }) + } + assert.ErrorContains(t, err, "account blob limit") + + // Now try with an allowlisted IP + // This should succeed because the account throughput limit is 100 KiB/s for quorum 0 + p = &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP("1.2.3.4"), + Port: 51001, + }, + } + ctx = peer.NewContext(context.Background(), p) + + _, err = dispersalServer.DisperseBlob(ctx, &pb.DisperseBlobRequest{ + Data: data50KiB, + SecurityParams: []*pb.SecurityParams{ + { + QuorumId: 0, + AdversaryThreshold: 50, + QuorumThreshold: 100, + }, + }, + }) + assert.NoError(t, err) + + // This should succeed because the account blob limit (5 blobs/s) X bucket size (3s) is larger than 10 blobs. 
+ for i := 0; i < 10; i++ { + _, err = dispersalServer.DisperseBlob(ctx, &pb.DisperseBlobRequest{ + Data: data1KiB, + SecurityParams: []*pb.SecurityParams{ + { + QuorumId: 1, + AdversaryThreshold: 50, + QuorumThreshold: 100, + }, + }, + }) + assert.NoError(t, err) + } +} + func setup(m *testing.M) { deployLocalStack = !(os.Getenv("DEPLOY_LOCALSTACK") == "false") @@ -333,15 +419,45 @@ func newTestServer(m *testing.M) *apiserver.DispersalServer { } blobMetadataStore := blobstore.NewBlobMetadataStore(dynamoClient, logger, metadataTableName, time.Hour) - var ratelimiter common.RateLimiter + globalParams := common.GlobalRateParams{ + CountFailed: false, + BucketSizes: []time.Duration{3 * time.Second}, + Multipliers: []float32{1}, + } + bucketStore, err := store.NewLocalParamStore[common.RateBucketParams](1000) + if err != nil { + panic("failed to create bucket store") + } + ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, logger) + rateConfig := apiserver.RateConfig{ QuorumRateInfos: map[core.QuorumID]apiserver.QuorumRateInfo{ 0: { - PerUserUnauthThroughput: 0, - TotalUnauthThroughput: 0, + PerUserUnauthThroughput: 20 * 1024, + TotalUnauthThroughput: 1048576, + PerUserUnauthBlobRate: 3 * 1e6, + TotalUnauthBlobRate: 10 * 1e6, + }, + 1: { + PerUserUnauthThroughput: 20 * 1024, + TotalUnauthThroughput: 1048576, + PerUserUnauthBlobRate: 3 * 1e6, + TotalUnauthBlobRate: 10 * 1e6, }, }, ClientIPHeader: "", + Allowlist: apiserver.Allowlist{ + "1.2.3.4": map[uint8]apiserver.PerUserRateInfo{ + 0: { + Throughput: 100 * 1024, + BlobRate: 5 * 1e6, + }, + 1: { + Throughput: 1024 * 1024, + BlobRate: 5 * 1e6, + }, + }, + }, } queue = blobstore.NewSharedStorage(bucketName, s3Client, blobMetadataStore, logger) diff --git a/disperser/cmd/apiserver/main.go b/disperser/cmd/apiserver/main.go index 082aedbb13..84c1a0ffca 100644 --- a/disperser/cmd/apiserver/main.go +++ b/disperser/cmd/apiserver/main.go @@ -109,7 +109,7 @@ func RunDisperserServer(ctx *cli.Context) error { 
return err } } - ratelimiter = ratelimit.NewRateLimiter(globalParams, bucketStore, config.RatelimiterConfig.Allowlist, logger) + ratelimiter = ratelimit.NewRateLimiter(globalParams, bucketStore, logger) } // TODO: create a separate metrics for batcher diff --git a/node/cmd/main.go b/node/cmd/main.go index f0f86d2d39..de58935dc5 100644 --- a/node/cmd/main.go +++ b/node/cmd/main.go @@ -80,7 +80,7 @@ func NodeMain(ctx *cli.Context) error { return err } - ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, []string{}, logger) + ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, logger) // Creates the GRPC server. server := grpc.NewServer(config, node, logger, ratelimiter)