diff --git a/admin/commands/collection/tx_rate_limiter.go b/admin/commands/collection/tx_rate_limiter.go index 0bb39df80a1..da953747955 100644 --- a/admin/commands/collection/tx_rate_limiter.go +++ b/admin/commands/collection/tx_rate_limiter.go @@ -3,7 +3,6 @@ package collection import ( "context" "fmt" - "strconv" "github.com/onflow/flow-go/admin" "github.com/onflow/flow-go/admin/commands" @@ -25,14 +24,19 @@ func NewTxRateLimitCommand(limiter *ingest.AddressRateLimiter) *TxRateLimitComma } func (s *TxRateLimitCommand) Handler(_ context.Context, req *admin.CommandRequest) (interface{}, error) { - input, ok := req.Data.(map[string]string) + input, ok := req.Data.(map[string]interface{}) if !ok { return admin.NewInvalidAdminReqFormatError("expected { \"command\": \"add|remove|get|get_config|set_config\", \"addresses\": \"addresses\""), nil } - cmd, ok := input["command"] + command, ok := input["command"] if !ok { - return admin.NewInvalidAdminReqErrorf("the \"command\" field is empty, must be either \"add\" or \"remove\" or \"get\""), nil + return admin.NewInvalidAdminReqErrorf("the \"command\" field is empty, must be one of add|remove|get|get_config|set_config"), nil + } + + cmd, ok := command.(string) + if !ok { + return admin.NewInvalidAdminReqErrorf("the \"command\" field is not string, must be one of add|remove|get|get_config|set_config"), nil } if cmd == "get" { @@ -40,10 +44,14 @@ func (s *TxRateLimitCommand) Handler(_ context.Context, req *admin.CommandReques } if cmd == "add" || cmd == "remove" { - addresses, ok := input["addresses"] + result, ok := input["addresses"] if !ok { return admin.NewInvalidAdminReqErrorf("the \"addresses\" field is empty, must be hex formated addresses, can be splitted by \",\""), nil } + addresses, ok := result.(string) + if !ok { + return admin.NewInvalidAdminReqErrorf("the \"addresses\" field is not string, must be hex formated addresses, can be splitted by \",\""), nil + } resp, err := s.AddOrRemove(cmd, addresses) if err != nil { @@ -58,19 +66,19 @@ func (s *TxRateLimitCommand) Handler(_ context.Context, req *admin.CommandReques } if cmd == "set_config" { - strLimit, limit_ok := input["limit"] - strBurst, burst_ok := input["burst"] + dataLimit, limit_ok := input["limit"] + dataBurst, burst_ok := input["burst"] if !limit_ok || !burst_ok { return admin.NewInvalidAdminReqErrorf("the \"limit\" or \"burst\" field is empty, must be number"), nil } - limit, err := strconv.ParseFloat(strLimit, 64) - if err == nil { - return admin.NewInvalidAdminReqErrorf("the \"limit\" field is not number: %v", strLimit), nil + limit, ok := dataLimit.(float64) + if !ok { + return admin.NewInvalidAdminReqErrorf("the \"limit\" field is not number: %v", dataLimit), nil } - burst, err := strconv.Atoi(strBurst) - if err == nil { - return admin.NewInvalidAdminReqErrorf("the \"burst\" field is not number: %v", strBurst), nil + burst, ok := dataBurst.(int) + if !ok { + return admin.NewInvalidAdminReqErrorf("the \"burst\" field is not number: %v", dataBurst), nil } s.limiter.SetLimitConfig(rate.Limit(limit), burst) diff --git a/engine/collection/ingest/rate_limiter_test.go b/engine/collection/ingest/rate_limiter_test.go index 09779720439..c38b6af5182 100644 --- a/engine/collection/ingest/rate_limiter_test.go +++ b/engine/collection/ingest/rate_limiter_test.go @@ -101,8 +101,13 @@ func TestLimiterWaitLongEnough(t *testing.T) { return l.Allow(addr1) }, 110*time.Millisecond, 10*time.Millisecond) + // block again until another 100 ms + require.True(t, l.IsRateLimited(addr1)) + // block until another 100 ms - require.False(t, l.Allow(addr1)) + require.Eventually(t, func() bool { + return l.Allow(addr1) + }, 110*time.Millisecond, 10*time.Millisecond) } func TestLimiterConcurrentSafe(t *testing.T) { @@ -141,3 +146,40 @@ func TestLimiterConcurrentSafe(t *testing.T) { wg.Wait() require.Equal(t, uint64(1), succeed.Load()) // should only succeed once } + +func TestLimiterGetSetConfig(t *testing.T) { + t.Parallel() + + addr1 := unittest.RandomAddressFixture() + + // with limit set to 10, it means we allow 10 messages per second, + // and with burst set to 1, it means we only allow 1 message at a time, + // so the limit is 1 message per 100 milliseconds. + // Note rate.Limit(0.1) is not to set 1 message per 100 milliseconds, but + // 1 message per 10 seconds. + numPerSec := rate.Limit(10) + burst := 1 + l := ingest.NewAddressRateLimiter(numPerSec, burst) + + l.AddAddress(addr1) + require.False(t, l.IsRateLimited(addr1)) + require.True(t, l.IsRateLimited(addr1)) + + limitConfig, burstConfig := l.GetLimitConfig() + require.Equal(t, numPerSec, limitConfig) + require.Equal(t, burst, burstConfig) + + // change from 1 message per 100 ms to 4 messages per 200 ms + l.SetLimitConfig(rate.Limit(20), 4) + + // verify the quota is reset, and the new limit is applied + for i := 0; i < 4; i++ { + require.False(t, l.IsRateLimited(addr1), fmt.Sprintf("fail at %v-th call", i)) + } + require.True(t, l.IsRateLimited(addr1)) + + // check every 10 Millisecond then after 100 Millisecond it should be allowed + require.Eventually(t, func() bool { + return l.Allow(addr1) + }, 210*time.Millisecond, 10*time.Millisecond) +}