Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: apply rate limit for the network topics #1332

Merged
merged 3 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ func TestExampleConfig(t *testing.T) {
if !(strings.HasPrefix(line, "# ") ||
strings.HasPrefix(line, "###") ||
strings.HasPrefix(line, " # ") ||
strings.HasPrefix(line, " # ")) {
strings.HasPrefix(line, " # ") ||
strings.HasPrefix(line, " # ")) {
exampleToml += line
exampleToml += "\n"
}
Expand Down
14 changes: 14 additions & 0 deletions config/example_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@
# Any incoming and outgoing connections to banned addresses will be terminated.
banned_nets = []

# `rate_limit` contains the rate limit configurations for network topics.
# The rate limit specifies the number of messages per second that are allowed.
# If set to zero, it allows all requests without any limit.`
[sync.firewall.rate_limit]

# `block_topic` specifies the rate limit for the block topic.
block_topic = 0

# `transaction_topic` specifies the rate limit for the transaction topic.
transaction_topic = 5

# `consensus_topic` specifies the rate limit for the consensus topic.
consensus_topic = 0

# `tx_pool` contains configuration options for the transaction pool module.
[tx_pool]

Expand Down
2 changes: 2 additions & 0 deletions network/gossip.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ func (g *gossipService) joinTopic(topicID TopicID, sp ShouldPropagate) (*lp2pps.
TopicID: topicID,
}
if !sp(msg) {
g.logger.Debug("message ignored", "from", peerId, "topic", topicID)

// Consume the message first
g.onReceiveMessage(m)

Expand Down
2 changes: 1 addition & 1 deletion sync/bundle/message/hello_ack.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ func (*HelloAckMessage) Type() Type {
}

func (m *HelloAckMessage) String() string {
return fmt.Sprintf("{%s: %s}", m.ResponseCode, m.Reason)
return fmt.Sprintf("{%s: %s %v}", m.ResponseCode, m.Reason, m.Height)
}
14 changes: 13 additions & 1 deletion sync/firewall/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,25 @@ import (
"net"
)

type RateLimit struct {
BlockTopic int `toml:"block_topic"`
TransactionTopic int `toml:"transaction_topic"`
ConsensusTopic int `toml:"consensus_topic"`
}

type Config struct {
BannedNets []string `toml:"banned_nets"`
BannedNets []string `toml:"banned_nets"`
RateLimit RateLimit `toml:"rate_limit"`
}

func DefaultConfig() *Config {
return &Config{
BannedNets: make([]string, 0),
RateLimit: RateLimit{
BlockTopic: 0,
TransactionTopic: 5,
ConsensusTopic: 0,
},
}
}

Expand Down
48 changes: 36 additions & 12 deletions sync/firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package firewall
import (
"bytes"
"io"
"time"

"github.com/multiformats/go-multiaddr"
"github.com/pactus-project/pactus/genesis"
Expand All @@ -14,16 +15,20 @@ import (
"github.com/pactus-project/pactus/util/errors"
"github.com/pactus-project/pactus/util/ipblocker"
"github.com/pactus-project/pactus/util/logger"
"github.com/pactus-project/pactus/util/ratelimit"
)

// Firewall check packets before passing them to sync module.
type Firewall struct {
config *Config
network network.Network
peerSet *peerset.PeerSet
state state.Facade
ipBlocker *ipblocker.IPBlocker
logger *logger.SubLogger
config *Config
network network.Network
peerSet *peerset.PeerSet
state state.Facade
ipBlocker *ipblocker.IPBlocker
blockRateLimit *ratelimit.RateLimit
transactionRateLimit *ratelimit.RateLimit
consensusRateLimit *ratelimit.RateLimit
logger *logger.SubLogger
}

func NewFirewall(conf *Config, net network.Network, peerSet *peerset.PeerSet, st state.Facade,
Expand All @@ -34,13 +39,20 @@ func NewFirewall(conf *Config, net network.Network, peerSet *peerset.PeerSet, st
return nil, err
}

blockRateLimit := ratelimit.NewRateLimit(conf.RateLimit.BlockTopic, time.Second)
transactionRateLimit := ratelimit.NewRateLimit(conf.RateLimit.TransactionTopic, time.Second)
consensusRateLimit := ratelimit.NewRateLimit(conf.RateLimit.ConsensusTopic, time.Second)

return &Firewall{
config: conf,
network: net,
peerSet: peerSet,
state: st,
ipBlocker: blocker,
logger: log,
config: conf,
network: net,
peerSet: peerSet,
state: st,
ipBlocker: blocker,
blockRateLimit: blockRateLimit,
transactionRateLimit: transactionRateLimit,
consensusRateLimit: consensusRateLimit,
logger: log,
}, nil
}

Expand Down Expand Up @@ -184,3 +196,15 @@ func (*Firewall) getIPFromMultiAddress(address string) (string, error) {

return ip, nil
}

func (f *Firewall) AllowBlockRequest() bool {
return f.blockRateLimit.AllowRequest()
}

func (f *Firewall) AllowTransactionRequest() bool {
return f.transactionRateLimit.AllowRequest()
}

func (f *Firewall) AllowConsensusRequest() bool {
return f.consensusRateLimit.AllowRequest()
}
30 changes: 30 additions & 0 deletions sync/firewall/firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,33 @@ func TestParseP2PAddr(t *testing.T) {
})
}
}

func TestAllowBlockRequest(t *testing.T) {
conf := DefaultConfig()
conf.RateLimit.BlockTopic = 1

td := setup(t, conf)

assert.True(t, td.firewall.AllowBlockRequest())
assert.False(t, td.firewall.AllowBlockRequest())
}

func TestAllowTransactionRequest(t *testing.T) {
conf := DefaultConfig()
conf.RateLimit.TransactionTopic = 1

td := setup(t, conf)

assert.True(t, td.firewall.AllowTransactionRequest())
assert.False(t, td.firewall.AllowTransactionRequest())
}

func TestAllowConsensusRequest(t *testing.T) {
conf := DefaultConfig()
conf.RateLimit.ConsensusTopic = 1

td := setup(t, conf)

assert.True(t, td.firewall.AllowConsensusRequest())
assert.False(t, td.firewall.AllowConsensusRequest())
}
16 changes: 10 additions & 6 deletions sync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ func NewSynchronizer(
}

func (sync *synchronizer) Start() error {
if err := sync.network.JoinTopic(network.TopicIDBlock, sync.shouldPropagateGeneralMessage); err != nil {
if err := sync.network.JoinTopic(network.TopicIDBlock, sync.shouldPropagateBlockMessage); err != nil {
return err
}
if err := sync.network.JoinTopic(network.TopicIDTransaction, sync.shouldPropagateGeneralMessage); err != nil {
if err := sync.network.JoinTopic(network.TopicIDTransaction, sync.shouldPropagateTransactionMessage); err != nil {
return err
}
// TODO: Not joining consensus topic when we are syncing
Expand Down Expand Up @@ -604,10 +604,14 @@ func (sync *synchronizer) prepareBlocks(from, count uint32) [][]byte {
return blocks
}

func (*synchronizer) shouldPropagateGeneralMessage(_ *network.GossipMessage) bool {
return true
func (sync *synchronizer) shouldPropagateBlockMessage(_ *network.GossipMessage) bool {
return sync.firewall.AllowBlockRequest()
}

func (*synchronizer) shouldPropagateConsensusMessage(_ *network.GossipMessage) bool {
return true
func (sync *synchronizer) shouldPropagateTransactionMessage(_ *network.GossipMessage) bool {
return sync.firewall.AllowTransactionRequest()
}

func (sync *synchronizer) shouldPropagateConsensusMessage(_ *network.GossipMessage) bool {
return sync.firewall.AllowConsensusRequest()
}
56 changes: 56 additions & 0 deletions util/ratelimit/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package ratelimit

import (
"sync"
"time"
)

type RateLimit struct {
lk sync.RWMutex

referenceTime time.Time
threshold int
counter int
window time.Duration
}

// NewRateLimit initializes a new RateLimit instance with the given threshold and window duration.
func NewRateLimit(threshold int, window time.Duration) *RateLimit {
return &RateLimit{
referenceTime: time.Now(),
threshold: threshold,
counter: 0,
window: window,
}
}

func (r *RateLimit) diff() time.Duration {
return time.Since(r.referenceTime)
}

func (r *RateLimit) reset() {
r.counter = 0
r.referenceTime = time.Now()
}

// AllowRequest increments the counter and checks if the rate limit is exceeded.
// If the threshold is zero, it allows all requests.
func (r *RateLimit) AllowRequest() bool {
r.lk.Lock()
defer r.lk.Unlock()

// If the threshold is zero, allow all requests
if r.threshold == 0 {
return true
}

// Check if the window has expired and reset if necessary
if r.diff() > r.window {
r.reset()
}

r.counter++

// Check if the threshold is exceeded
return r.counter <= r.threshold
}
54 changes: 54 additions & 0 deletions util/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package ratelimit

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestRateLimit(t *testing.T) {
threshold := 5
window := 100 * time.Millisecond
r := NewRateLimit(threshold, window)

t.Run("InitialState", func(t *testing.T) {
assert.Equal(t, 0, r.counter)
})

t.Run("AllowRequestWithinThreshold", func(t *testing.T) {
for i := 0; i < threshold; i++ {
assert.True(t, r.AllowRequest())
}
assert.Equal(t, threshold, r.counter)
})

t.Run("ExceedThreshold", func(t *testing.T) {
assert.False(t, r.AllowRequest())
})

t.Run("ResetAfterWindow", func(t *testing.T) {
time.Sleep(window + 10*time.Millisecond)
assert.True(t, r.AllowRequest())
assert.Equal(t, 1, r.counter)
})

t.Run("ResetMethod", func(t *testing.T) {
r.reset()
assert.Equal(t, 0, r.counter)
assert.True(t, r.AllowRequest())
assert.Equal(t, 1, r.counter)
})

t.Run("DiffMethod", func(t *testing.T) {
assert.LessOrEqual(t, r.diff(), window)
})
}

func TestRateLimitZeroThreshold(t *testing.T) {
window := 100 * time.Millisecond
r := NewRateLimit(0, window)

assert.True(t, r.AllowRequest())
assert.Zero(t, r.counter)
}
Loading