Skip to content

Commit

Permalink
Implemented token bucket rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
RostislavPorohnya committed Jul 25, 2024
1 parent 974fa12 commit b6238b5
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Added Queries Rate Limiter which uses Token Bucket algorithm to the Session struct.
Added RateLimiterConfig to the ClusterConfig struct. (#1731)

### Changed

Expand Down
3 changes: 3 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ type ClusterConfig struct {

// internal config for testing
disableControlConn bool

// If Session has RateLimiterConfig then queries will be limited using RateLimiter
RateLimiterConfig *RateLimiterConfig
}

type Dialer interface {
Expand Down
76 changes: 76 additions & 0 deletions rate_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package gocql

import (
"sync"
"time"
)

// RateLimiterConfig holds the configuration parameters for the rate limiter, which uses Token Bucket approach.
//
// Fields:
//
// - rate: Allowed requests per second
// - Burst: Maximum number of burst requests
//
// Example:
// RateLimiterConfig{
// rate: 300000,
// burst: 150,
// }
type RateLimiterConfig struct {
rate float64
burst int
}

type tokenBucket struct {
rate float64
burst int
tokens int
lastRefilled time.Time
mu sync.Mutex
}

func (tb *tokenBucket) refill() {
tb.mu.Lock()
defer tb.mu.Unlock()
now := time.Now()
tokensToAdd := int(tb.rate * now.Sub(tb.lastRefilled).Seconds())
tb.tokens = min(tb.tokens+tokensToAdd, tb.burst)
tb.lastRefilled = now
}

func (tb *tokenBucket) Allow() bool {
tb.refill()
tb.mu.Lock()
defer tb.mu.Unlock()
if tb.tokens > 0 {
tb.tokens--
return true
}
return false
}

func min(a, b int) int {
if a < b {
return a
}
return b
}

type ConfigurableRateLimiter struct {
tb tokenBucket
}

func NewConfigurableRateLimiter(rate float64, burst int) *ConfigurableRateLimiter {
tb := tokenBucket{
rate: rate,
burst: burst,
tokens: burst,
lastRefilled: time.Now(),
}
return &ConfigurableRateLimiter{tb}
}

func (rl *ConfigurableRateLimiter) Allow() bool {
return rl.tb.Allow()
}
78 changes: 78 additions & 0 deletions rate_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package gocql

import (
"fmt"
"sync"
"testing"
)

const queries = 100

const skipRateLimiterTestMsg = "Skipping rate limiter test, due to limit of simultaneously alive goroutines. Should be tested locally"

func TestRateLimiter50k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 50_000 workers")
RunRateLimiterTest(t, 50_000)
}

func TestRateLimiter100k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 100_000 workers")
RunRateLimiterTest(t, 100_000)
}

func TestRateLimiter200k(t *testing.T) {
t.Skip(skipRateLimiterTestMsg)
fmt.Println("Running rate limiter test with 200_000 workers")
RunRateLimiterTest(t, 200_000)
}

func RunRateLimiterTest(t *testing.T, workerCount int) {
cluster := createCluster()
cluster.RateLimiterConfig = &RateLimiterConfig{
rate: 300000,
burst: 100,
}

session := createSessionFromCluster(cluster, t)
defer session.Close()

execRelease(session.Query("drop keyspace if exists pargettest"))
execRelease(session.Query("create keyspace pargettest with replication = {'class' : 'SimpleStrategy', 'replication_factor' : 1}"))
execRelease(session.Query("drop table if exists pargettest.test"))
execRelease(session.Query("create table pargettest.test (a text, b int, primary key(a))"))
execRelease(session.Query("insert into pargettest.test (a, b) values ( 'a', 1)"))

var wg sync.WaitGroup

for i := 1; i <= workerCount; i++ {
wg.Add(1)

go func() {
defer wg.Done()
for j := 0; j < queries; j++ {
iterRelease(session.Query("select * from pargettest.test where a='a'"))
}
}()
}

wg.Wait()
}

func iterRelease(query *Query) {
_, err := query.Iter().SliceMap()
if err != nil {
println(err.Error())
panic(err)
}
query.Release()
}

func execRelease(query *Query) {
if err := query.Exec(); err != nil {
println(err.Error())
panic(err)
}
query.Release()
}
12 changes: 12 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ type Session struct {
isInitialized bool

logger StdLogger

rateLimiter *ConfigurableRateLimiter
}

var queryPool = &sync.Pool{
Expand Down Expand Up @@ -188,6 +190,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
s.frameObserver = cfg.FrameHeaderObserver
s.streamObserver = cfg.StreamObserver

if cfg.RateLimiterConfig != nil {
s.rateLimiter = NewConfigurableRateLimiter(cfg.RateLimiterConfig.rate, cfg.RateLimiterConfig.burst)
}

//Check the TLS Config before trying to connect to anything external
connCfg, err := connConfig(&s.cfg)
if err != nil {
Expand Down Expand Up @@ -452,6 +458,12 @@ func (s *Session) SetTrace(trace Tracer) {
// value before the query is executed. Query is automatically prepared
// if it has not previously been executed.
func (s *Session) Query(stmt string, values ...interface{}) *Query {
if s.rateLimiter != nil {
for !s.rateLimiter.Allow() {
time.Sleep(time.Millisecond * 50)
}
}

qry := queryPool.Get().(*Query)
qry.session = s
qry.stmt = stmt
Expand Down

0 comments on commit b6238b5

Please sign in to comment.