diff --git a/dot/network/service.go b/dot/network/service.go index effb2b7506..32d23cc852 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -67,12 +67,13 @@ type Service struct { ctx context.Context cancel context.CancelFunc - cfg *Config - host *host - mdns *mdns - gossip *gossip - syncQueue *syncQueue - bufPool *sizedBufferPool + cfg *Config + host *host + mdns *mdns + gossip *gossip + syncQueue *syncQueue + bufPool *sizedBufferPool + streamManager *streamManager notificationsProtocols map[byte]*notificationsProtocol // map of sub-protocol msg ID to protocol info notificationsMu sync.RWMutex @@ -162,6 +163,7 @@ func NewService(cfg *Config) (*Service, error) { telemetryInterval: cfg.telemetryInterval, closeCh: make(chan interface{}), bufPool: bufPool, + streamManager: newStreamManager(ctx), } network.syncQueue = newSyncQueue(network) @@ -267,6 +269,7 @@ func (s *Service) Start() error { go s.logPeerCount() go s.publishNetworkTelemetry(s.closeCh) go s.sentBlockIntervalTelemetry() + s.streamManager.start() return nil } @@ -529,6 +532,8 @@ func isInbound(stream libp2pnetwork.Stream) bool { } func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder, handler messageHandler) { + s.streamManager.logNewStream(stream) + peer := stream.Conn().RemotePeer() msgBytes := s.bufPool.get() defer s.bufPool.put(&msgBytes) @@ -543,6 +548,8 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder return } + s.streamManager.logMessageReceived(stream.ID()) + // decode message based on message type msg, err := decoder(msgBytes[:tot], peer, isInbound(stream)) if err != nil { diff --git a/dot/network/stream_manager.go b/dot/network/stream_manager.go new file mode 100644 index 0000000000..6755f5c3da --- /dev/null +++ b/dot/network/stream_manager.go @@ -0,0 +1,81 @@ +package network + +import ( + "context" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/network" +) + +var cleanupStreamInterval = time.Minute + +type streamData struct { + lastReceivedMessage time.Time + stream network.Stream +} + +// streamManager tracks inbound streams and runs a cleanup goroutine every `cleanupStreamInterval` to close streams that +// we haven't received any data on for the last time period. this prevents keeping stale streams open and continuously trying to +// read from it, which takes up lots of CPU over time. +type streamManager struct { + ctx context.Context + streamDataMap *sync.Map //map[string]*streamData +} + +func newStreamManager(ctx context.Context) *streamManager { + return &streamManager{ + ctx: ctx, + streamDataMap: new(sync.Map), + } +} + +func (sm *streamManager) start() { + go func() { + ticker := time.NewTicker(cleanupStreamInterval) + defer ticker.Stop() + + for { + select { + case <-sm.ctx.Done(): + return + case <-ticker.C: + sm.cleanupStreams() + } + } + }() +} + +func (sm *streamManager) cleanupStreams() { + sm.streamDataMap.Range(func(id, data interface{}) bool { + sdata := data.(*streamData) + lastReceived := sdata.lastReceivedMessage + stream := sdata.stream + + if time.Since(lastReceived) > cleanupStreamInterval { + _ = stream.Close() + sm.streamDataMap.Delete(id) + } + + return true + }) +} + +func (sm *streamManager) logNewStream(stream network.Stream) { + data := &streamData{ + lastReceivedMessage: time.Now(), // prevents closing just opened streams, in case the cleanup goroutine runs at the same time stream is opened + stream: stream, + } + sm.streamDataMap.Store(stream.ID(), data) +} + +func (sm *streamManager) logMessageReceived(streamID string) { + data, has := sm.streamDataMap.Load(streamID) + if !has { + return + } + + sdata := data.(*streamData) + sdata.lastReceivedMessage = time.Now() + sm.streamDataMap.Store(streamID, sdata) +} diff --git a/dot/network/stream_manager_test.go b/dot/network/stream_manager_test.go new file mode 100644 index 0000000000..c831f01144 --- /dev/null +++ b/dot/network/stream_manager_test.go @@ -0,0 +1,102 @@ +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p" + libp2phost "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func setupStreamManagerTest(t *testing.T) (context.Context, []libp2phost.Host, []*streamManager) { + ctx, cancel := context.WithCancel(context.Background()) + + cleanupStreamInterval = time.Millisecond * 500 + t.Cleanup(func() { + cleanupStreamInterval = time.Minute + cancel() + }) + + smA := newStreamManager(ctx) + smB := newStreamManager(ctx) + + portA := 7001 + portB := 7002 + addrA, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portA)) + require.NoError(t, err) + addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portB)) + require.NoError(t, err) + + ha, err := libp2p.New( + ctx, libp2p.ListenAddrs(addrA), + ) + require.NoError(t, err) + + hb, err := libp2p.New( + ctx, libp2p.ListenAddrs(addrB), + ) + require.NoError(t, err) + + err = ha.Connect(ctx, peer.AddrInfo{ + ID: hb.ID(), + Addrs: hb.Addrs(), + }) + require.NoError(t, err) + + hb.SetStreamHandler("", func(stream network.Stream) { + smB.logNewStream(stream) + }) + + return ctx, []libp2phost.Host{ha, hb}, []*streamManager{smA, smB} +} + +func TestStreamManager(t *testing.T) { + ctx, hosts, sms := setupStreamManagerTest(t) + ha, hb := hosts[0], hosts[1] + smA, smB := sms[0], sms[1] + + stream, err := ha.NewStream(ctx, hb.ID(), "") + require.NoError(t, err) + + smA.logNewStream(stream) + smA.start() + smB.start() + + time.Sleep(cleanupStreamInterval * 2) + connsAToB := ha.Network().ConnsToPeer(hb.ID()) + require.Equal(t, 1, len(connsAToB)) + require.Equal(t, 0, len(connsAToB[0].GetStreams())) + + connsBToA := hb.Network().ConnsToPeer(ha.ID()) + require.Equal(t, 1, len(connsBToA)) + require.Equal(t, 0, len(connsBToA[0].GetStreams())) +} + +func TestStreamManager_KeepStream(t *testing.T) { + ctx, hosts, sms := setupStreamManagerTest(t) + ha, hb := hosts[0], hosts[1] + smA, smB := sms[0], sms[1] + + stream, err := ha.NewStream(ctx, hb.ID(), "") + require.NoError(t, err) + + smA.logNewStream(stream) + smA.start() + smB.start() + + time.Sleep(cleanupStreamInterval / 2) + connsAToB := ha.Network().ConnsToPeer(hb.ID()) + require.Equal(t, 1, len(connsAToB)) + require.Equal(t, 1, len(connsAToB[0].GetStreams())) + + connsBToA := hb.Network().ConnsToPeer(ha.ID()) + require.Equal(t, 1, len(connsBToA)) + require.Equal(t, 1, len(connsBToA[0].GetStreams())) +} diff --git a/dot/network/utils.go b/dot/network/utils.go index 771cabda08..211f247a8f 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -184,9 +184,7 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { ) length, err := readLEB128ToUint64(stream, buf[:1]) - if err == io.EOF { - return 0, err - } else if err != nil { + if err != nil { return 0, err // TODO: return bytes read from readLEB128ToUint64 } @@ -196,13 +194,11 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { if length > uint64(len(buf)) { logger.Warn("received message with size greater than allocated message buffer", "length", length, "buffer size", len(buf)) - _ = stream.Close() return 0, fmt.Errorf("message size greater than allocated message buffer: got %d", length) } if length > maxBlockResponseSize { logger.Warn("received message with size greater than maxBlockResponseSize, closing stream", "length", length) - _ = stream.Close() return 0, fmt.Errorf("message size greater than maximum: got %d", length) }