-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(dot/network): implement streamManager to cleanup not recently us…
…ed streams (#1611)
- Loading branch information
Showing
4 changed files
with
197 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package network | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
"time" | ||
|
||
"github.com/libp2p/go-libp2p-core/network" | ||
) | ||
|
||
var cleanupStreamInterval = time.Minute | ||
|
||
type streamData struct { | ||
lastReceivedMessage time.Time | ||
stream network.Stream | ||
} | ||
|
||
// streamManager tracks inbound streams and runs a cleanup goroutine every `cleanupStreamInterval` to close streams that | ||
// we haven't received any data on for the last time period. this prevents keeping stale streams open and continuously trying to | ||
// read from it, which takes up lots of CPU over time. | ||
type streamManager struct { | ||
ctx context.Context | ||
streamDataMap *sync.Map //map[string]*streamData | ||
} | ||
|
||
func newStreamManager(ctx context.Context) *streamManager { | ||
return &streamManager{ | ||
ctx: ctx, | ||
streamDataMap: new(sync.Map), | ||
} | ||
} | ||
|
||
func (sm *streamManager) start() { | ||
go func() { | ||
ticker := time.NewTicker(cleanupStreamInterval) | ||
defer ticker.Stop() | ||
|
||
for { | ||
select { | ||
case <-sm.ctx.Done(): | ||
return | ||
case <-ticker.C: | ||
sm.cleanupStreams() | ||
} | ||
} | ||
}() | ||
} | ||
|
||
func (sm *streamManager) cleanupStreams() { | ||
sm.streamDataMap.Range(func(id, data interface{}) bool { | ||
sdata := data.(*streamData) | ||
lastReceived := sdata.lastReceivedMessage | ||
stream := sdata.stream | ||
|
||
if time.Since(lastReceived) > cleanupStreamInterval { | ||
_ = stream.Close() | ||
sm.streamDataMap.Delete(id) | ||
} | ||
|
||
return true | ||
}) | ||
} | ||
|
||
func (sm *streamManager) logNewStream(stream network.Stream) { | ||
data := &streamData{ | ||
lastReceivedMessage: time.Now(), // prevents closing just opened streams, in case the cleanup goroutine runs at the same time stream is opened | ||
stream: stream, | ||
} | ||
sm.streamDataMap.Store(stream.ID(), data) | ||
} | ||
|
||
func (sm *streamManager) logMessageReceived(streamID string) { | ||
data, has := sm.streamDataMap.Load(streamID) | ||
if !has { | ||
return | ||
} | ||
|
||
sdata := data.(*streamData) | ||
sdata.lastReceivedMessage = time.Now() | ||
sm.streamDataMap.Store(streamID, sdata) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package network | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/libp2p/go-libp2p" | ||
libp2phost "github.com/libp2p/go-libp2p-core/host" | ||
"github.com/libp2p/go-libp2p-core/network" | ||
"github.com/libp2p/go-libp2p-core/peer" | ||
ma "github.com/multiformats/go-multiaddr" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func setupStreamManagerTest(t *testing.T) (context.Context, []libp2phost.Host, []*streamManager) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
|
||
cleanupStreamInterval = time.Millisecond * 500 | ||
t.Cleanup(func() { | ||
cleanupStreamInterval = time.Minute | ||
cancel() | ||
}) | ||
|
||
smA := newStreamManager(ctx) | ||
smB := newStreamManager(ctx) | ||
|
||
portA := 7001 | ||
portB := 7002 | ||
addrA, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portA)) | ||
require.NoError(t, err) | ||
addrB, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", portB)) | ||
require.NoError(t, err) | ||
|
||
ha, err := libp2p.New( | ||
ctx, libp2p.ListenAddrs(addrA), | ||
) | ||
require.NoError(t, err) | ||
|
||
hb, err := libp2p.New( | ||
ctx, libp2p.ListenAddrs(addrB), | ||
) | ||
require.NoError(t, err) | ||
|
||
err = ha.Connect(ctx, peer.AddrInfo{ | ||
ID: hb.ID(), | ||
Addrs: hb.Addrs(), | ||
}) | ||
require.NoError(t, err) | ||
|
||
hb.SetStreamHandler("", func(stream network.Stream) { | ||
smB.logNewStream(stream) | ||
}) | ||
|
||
return ctx, []libp2phost.Host{ha, hb}, []*streamManager{smA, smB} | ||
} | ||
|
||
func TestStreamManager(t *testing.T) { | ||
ctx, hosts, sms := setupStreamManagerTest(t) | ||
ha, hb := hosts[0], hosts[1] | ||
smA, smB := sms[0], sms[1] | ||
|
||
stream, err := ha.NewStream(ctx, hb.ID(), "") | ||
require.NoError(t, err) | ||
|
||
smA.logNewStream(stream) | ||
smA.start() | ||
smB.start() | ||
|
||
time.Sleep(cleanupStreamInterval * 2) | ||
connsAToB := ha.Network().ConnsToPeer(hb.ID()) | ||
require.Equal(t, 1, len(connsAToB)) | ||
require.Equal(t, 0, len(connsAToB[0].GetStreams())) | ||
|
||
connsBToA := hb.Network().ConnsToPeer(ha.ID()) | ||
require.Equal(t, 1, len(connsBToA)) | ||
require.Equal(t, 0, len(connsBToA[0].GetStreams())) | ||
} | ||
|
||
func TestStreamManager_KeepStream(t *testing.T) { | ||
ctx, hosts, sms := setupStreamManagerTest(t) | ||
ha, hb := hosts[0], hosts[1] | ||
smA, smB := sms[0], sms[1] | ||
|
||
stream, err := ha.NewStream(ctx, hb.ID(), "") | ||
require.NoError(t, err) | ||
|
||
smA.logNewStream(stream) | ||
smA.start() | ||
smB.start() | ||
|
||
time.Sleep(cleanupStreamInterval / 2) | ||
connsAToB := ha.Network().ConnsToPeer(hb.ID()) | ||
require.Equal(t, 1, len(connsAToB)) | ||
require.Equal(t, 1, len(connsAToB[0].GetStreams())) | ||
|
||
connsBToA := hb.Network().ConnsToPeer(ha.ID()) | ||
require.Equal(t, 1, len(connsBToA)) | ||
require.Equal(t, 1, len(connsBToA[0].GetStreams())) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters