Skip to content

Commit

Permalink
feat(dot/network): implement streamManager to cleanup not recently us…
Browse files Browse the repository at this point in the history
…ed streams (ChainSafe#1611)
  • Loading branch information
noot authored Jun 2, 2021
1 parent dd3838c commit ba861bf
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 11 deletions.
19 changes: 13 additions & 6 deletions dot/network/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ type Service struct {
ctx context.Context
cancel context.CancelFunc

cfg *Config
host *host
mdns *mdns
gossip *gossip
syncQueue *syncQueue
bufPool *sizedBufferPool
cfg *Config
host *host
mdns *mdns
gossip *gossip
syncQueue *syncQueue
bufPool *sizedBufferPool
streamManager *streamManager

notificationsProtocols map[byte]*notificationsProtocol // map of sub-protocol msg ID to protocol info
notificationsMu sync.RWMutex
Expand Down Expand Up @@ -162,6 +163,7 @@ func NewService(cfg *Config) (*Service, error) {
telemetryInterval: cfg.telemetryInterval,
closeCh: make(chan interface{}),
bufPool: bufPool,
streamManager: newStreamManager(ctx),
}

network.syncQueue = newSyncQueue(network)
Expand Down Expand Up @@ -267,6 +269,7 @@ func (s *Service) Start() error {
go s.logPeerCount()
go s.publishNetworkTelemetry(s.closeCh)
go s.sentBlockIntervalTelemetry()
s.streamManager.start()

return nil
}
Expand Down Expand Up @@ -529,6 +532,8 @@ func isInbound(stream libp2pnetwork.Stream) bool {
}

func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) {
s.streamManager.logNewStream(stream)

peer := stream.Conn().RemotePeer()
msgBytes := s.bufPool.get()
defer s.bufPool.put(&msgBytes)
Expand All @@ -543,6 +548,8 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder
return
}

s.streamManager.logMessageReceived(stream.ID())

// decode message based on message type
msg, err := decoder(msgBytes[:tot], peer, isInbound(stream))
if err != nil {
Expand Down
81 changes: 81 additions & 0 deletions dot/network/stream_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package network

import (
"context"
"sync"
"time"

"github.com/libp2p/go-libp2p-core/network"
)

var cleanupStreamInterval = time.Minute

type streamData struct {
lastReceivedMessage time.Time
stream network.Stream
}

// streamManager tracks inbound streams and runs a cleanup goroutine every `cleanupStreamInterval` to close streams that
// we haven't received any data on for the last time period. this prevents keeping stale streams open and continuously trying to
// read from it, which takes up lots of CPU over time.
type streamManager struct {
ctx context.Context
streamDataMap *sync.Map //map[string]*streamData
}

func newStreamManager(ctx context.Context) *streamManager {
return &streamManager{
ctx: ctx,
streamDataMap: new(sync.Map),
}
}

func (sm *streamManager) start() {
go func() {
ticker := time.NewTicker(cleanupStreamInterval)
defer ticker.Stop()

for {
select {
case <-sm.ctx.Done():
return
case <-ticker.C:
sm.cleanupStreams()
}
}
}()
}

func (sm *streamManager) cleanupStreams() {
sm.streamDataMap.Range(func(id, data interface{}) bool {
sdata := data.(*streamData)
lastReceived := sdata.lastReceivedMessage
stream := sdata.stream

if time.Since(lastReceived) > cleanupStreamInterval {
_ = stream.Close()
sm.streamDataMap.Delete(id)
}

return true
})
}

func (sm *streamManager) logNewStream(stream network.Stream) {
data := &streamData{
lastReceivedMessage: time.Now(), // prevents closing just opened streams, in case the cleanup goroutine runs at the same time stream is opened
stream: stream,
}
sm.streamDataMap.Store(stream.ID(), data)
}

func (sm *streamManager) logMessageReceived(streamID string) {
data, has := sm.streamDataMap.Load(streamID)
if !has {
return
}

sdata := data.(*streamData)
sdata.lastReceivedMessage = time.Now()
sm.streamDataMap.Store(streamID, sdata)
}
102 changes: 102 additions & 0 deletions dot/network/stream_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package network

import (
"context"
"fmt"
"testing"
"time"

"github.com/libp2p/go-libp2p"
libp2phost "github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr"

"github.com/stretchr/testify/require"
)

func setupStreamManagerTest(t *testing.T) (context.Context, []libp2phost.Host, []*streamManager) {
ctx, cancel := context.WithCancel(context.Background())

cleanupStreamInterval = time.Millisecond * 500
t.Cleanup(func() {
cleanupStreamInterval = time.Minute
cancel()
})

smA := newStreamManager(ctx)
smB := newStreamManager(ctx)

portA := 7001
portB := 7002
addrA, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portA))
require.NoError(t, err)
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portB))
require.NoError(t, err)

ha, err := libp2p.New(
ctx, libp2p.ListenAddrs(addrA),
)
require.NoError(t, err)

hb, err := libp2p.New(
ctx, libp2p.ListenAddrs(addrB),
)
require.NoError(t, err)

err = ha.Connect(ctx, peer.AddrInfo{
ID: hb.ID(),
Addrs: hb.Addrs(),
})
require.NoError(t, err)

hb.SetStreamHandler("", func(stream network.Stream) {
smB.logNewStream(stream)
})

return ctx, []libp2phost.Host{ha, hb}, []*streamManager{smA, smB}
}

func TestStreamManager(t *testing.T) {
ctx, hosts, sms := setupStreamManagerTest(t)
ha, hb := hosts[0], hosts[1]
smA, smB := sms[0], sms[1]

stream, err := ha.NewStream(ctx, hb.ID(), "")
require.NoError(t, err)

smA.logNewStream(stream)
smA.start()
smB.start()

time.Sleep(cleanupStreamInterval * 2)
connsAToB := ha.Network().ConnsToPeer(hb.ID())
require.Equal(t, 1, len(connsAToB))
require.Equal(t, 0, len(connsAToB[0].GetStreams()))

connsBToA := hb.Network().ConnsToPeer(ha.ID())
require.Equal(t, 1, len(connsBToA))
require.Equal(t, 0, len(connsBToA[0].GetStreams()))
}

func TestStreamManager_KeepStream(t *testing.T) {
ctx, hosts, sms := setupStreamManagerTest(t)
ha, hb := hosts[0], hosts[1]
smA, smB := sms[0], sms[1]

stream, err := ha.NewStream(ctx, hb.ID(), "")
require.NoError(t, err)

smA.logNewStream(stream)
smA.start()
smB.start()

time.Sleep(cleanupStreamInterval / 2)
connsAToB := ha.Network().ConnsToPeer(hb.ID())
require.Equal(t, 1, len(connsAToB))
require.Equal(t, 1, len(connsAToB[0].GetStreams()))

connsBToA := hb.Network().ConnsToPeer(ha.ID())
require.Equal(t, 1, len(connsBToA))
require.Equal(t, 1, len(connsBToA[0].GetStreams()))
}
6 changes: 1 addition & 5 deletions dot/network/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {
)

length, err := readLEB128ToUint64(stream, buf[:1])
if err == io.EOF {
return 0, err
} else if err != nil {
if err != nil {
return 0, err // TODO: return bytes read from readLEB128ToUint64
}

Expand All @@ -196,13 +194,11 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) {

if length > uint64(len(buf)) {
logger.Warn("received message with size greater than allocated message buffer", "length", length, "buffer size", len(buf))
_ = stream.Close()
return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length)
}

if length > maxBlockResponseSize {
logger.Warn("received message with size greater than maxBlockResponseSize, closing stream", "length", length)
_ = stream.Close()
return 0, fmt.Errorf("message size greater than maximum: got %d", length)
}

Expand Down

0 comments on commit ba861bf

Please sign in to comment.