diff --git a/p2p/stream/common/streammanager/config.go b/p2p/stream/common/streammanager/config.go new file mode 100644 index 0000000000..671b174f36 --- /dev/null +++ b/p2p/stream/common/streammanager/config.go @@ -0,0 +1,25 @@ +package streammanager + +import "time" + +const ( + // checkInterval is the default interval for checking stream number. If the stream + // number is smaller than softLoCap, an active discover through DHT will be triggered. + checkInterval = 30 * time.Second + // discTimeout is the timeout for one batch of discovery + discTimeout = 10 * time.Second + // connectTimeout is the timeout for setting up a stream with a discovered peer + connectTimeout = 60 * time.Second +) + +// Config is the config for stream manager +type Config struct { + // HardLoCap is low cap of stream number that immediately trigger discovery + HardLoCap int + // SoftLoCap is low cap of stream number that will trigger discovery during stream check + SoftLoCap int + // HiCap is the high cap of stream number + HiCap int + // DiscBatch is the size of each discovery + DiscBatch int +} diff --git a/p2p/stream/common/streammanager/events.go b/p2p/stream/common/streammanager/events.go new file mode 100644 index 0000000000..c41fa9de2a --- /dev/null +++ b/p2p/stream/common/streammanager/events.go @@ -0,0 +1,28 @@ +package streammanager + +import ( + "github.com/ethereum/go-ethereum/event" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" +) + +// EvtStreamAdded is the event of adding a new stream +type ( + EvtStreamAdded struct { + Stream sttypes.Stream + } + + // EvtStreamRemoved is an event of stream removed + EvtStreamRemoved struct { + ID sttypes.StreamID + } +) + +// SubscribeAddStreamEvent subscribe the add stream event +func (sm *streamManager) SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription { + return sm.addStreamFeed.Subscribe(ch) +} + +// SubscribeRemoveStreamEvent subscribe the remove stream event +func (sm *streamManager) SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription { + return sm.removeStreamFeed.Subscribe(ch) +} diff --git a/p2p/stream/common/streammanager/events_test.go b/p2p/stream/common/streammanager/events_test.go new file mode 100644 index 0000000000..7f89f4fade --- /dev/null +++ b/p2p/stream/common/streammanager/events_test.go @@ -0,0 +1,73 @@ +package streammanager + +import ( + "sync/atomic" + "testing" + "time" +) + +func TestStreamManager_SubscribeAddStreamEvent(t *testing.T) { + sm := newTestStreamManager() + + addStreamEvtC := make(chan EvtStreamAdded, 1) + sub := sm.SubscribeAddStreamEvent(addStreamEvtC) + defer sub.Unsubscribe() + stopC := make(chan struct{}, 1) + + var numStreamAdded uint32 + go func() { + for { + select { + case <-addStreamEvtC: + atomic.AddUint32(&numStreamAdded, 1) + case <-stopC: + return + } + } + }() + + sm.Start() + time.Sleep(defTestWait) + close(stopC) + sm.Close() + + if atomic.LoadUint32(&numStreamAdded) != 16 { + t.Errorf("numStreamAdded unexpected") + } +} + +func TestStreamManager_SubscribeRemoveStreamEvent(t *testing.T) { + sm := newTestStreamManager() + + rmStreamEvtC := make(chan EvtStreamRemoved, 1) + sub := sm.SubscribeRemoveStreamEvent(rmStreamEvtC) + defer sub.Unsubscribe() + stopC := make(chan struct{}, 1) + + var numStreamRemoved uint32 + go func() { + for { + select { + case <-rmStreamEvtC: + atomic.AddUint32(&numStreamRemoved, 1) + case <-stopC: + return + } + } + }() + + sm.Start() + time.Sleep(defTestWait) + + err := sm.RemoveStream(makeStreamID(1)) + if err != nil { + t.Fatal(err) + } + time.Sleep(defTestWait) + close(stopC) + sm.Close() + + if atomic.LoadUint32(&numStreamRemoved) != 1 { + t.Errorf("numStreamAdded unexpected") + } +} diff --git a/p2p/stream/common/streammanager/interface.go b/p2p/stream/common/streammanager/interface.go new file mode 100644 index 0000000000..5c34488d86 --- /dev/null +++ b/p2p/stream/common/streammanager/interface.go @@ -0,0 +1,50 @@ +package streammanager + +import ( + "context" + + "github.com/ethereum/go-ethereum/event" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + p2ptypes "github.com/harmony-one/harmony/p2p/types" + "github.com/libp2p/go-libp2p-core/network" + libp2p_peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" +) + +// StreamManager is the interface for streamManager +type StreamManager interface { + p2ptypes.LifeCycle + StreamOperator + Subscriber + StreamReader +} + +// StreamOperator handles new stream or remove stream +type StreamOperator interface { + NewStream(stream sttypes.Stream) error + RemoveStream(stID sttypes.StreamID) error +} + +// Subscriber is the interface to support stream event subscription +type Subscriber interface { + SubscribeAddStreamEvent(ch chan<- EvtStreamAdded) event.Subscription + SubscribeRemoveStreamEvent(ch chan<- EvtStreamRemoved) event.Subscription +} + +// StreamReader is the interface to read stream in stream manager +type StreamReader interface { + GetStreams() []sttypes.Stream + GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) +} + +// host is the adapter interface of the libp2p host implementation. +// TODO: further adapt the host +type host interface { + ID() libp2p_peer.ID + NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) +} + +// peerFinder is the adapter interface of discovery.Discovery +type peerFinder interface { + FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) +} diff --git a/p2p/stream/common/streammanager/interface_test.go b/p2p/stream/common/streammanager/interface_test.go new file mode 100644 index 0000000000..dc657050e6 --- /dev/null +++ b/p2p/stream/common/streammanager/interface_test.go @@ -0,0 +1,203 @@ +package streammanager + +import ( + "context" + "errors" + "strconv" + "sync" + "sync/atomic" + + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/libp2p/go-libp2p-core/network" + libp2p_peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" +) + +var _ StreamManager = &streamManager{} + +var ( + myPeerID = makePeerID(0) + testProtoID = sttypes.ProtoID("harmony/sync/unitest/0/1.0.0") +) + +const ( + defHardLoCap = 16 // discovery trigger immediately when size smaller than this number + defSoftLoCap = 32 // discovery trigger for routine check + defHiCap = 128 // Hard cap of the stream number + defDiscBatch = 16 // batch size for discovery +) + +var defConfig = Config{ + HardLoCap: defHardLoCap, + SoftLoCap: defSoftLoCap, + HiCap: defHiCap, + DiscBatch: defDiscBatch, +} + +func newTestStreamManager() *streamManager { + pid := testProtoID + host := newTestHost() + pf := newTestPeerFinder(makeRemotePeers(100), emptyDelayFunc) + + sm := newStreamManager(pid, host, pf, nil, defConfig) + host.sm = sm + return sm +} + +type testStream struct { + id sttypes.StreamID + proto sttypes.ProtoID + closed bool +} + +func newTestStream(id sttypes.StreamID, proto sttypes.ProtoID) *testStream { + return &testStream{id: id, proto: proto} +} + +func (st *testStream) ID() sttypes.StreamID { + return st.id +} + +func (st *testStream) ProtoID() sttypes.ProtoID { + return st.proto +} + +func (st *testStream) WriteBytes([]byte) error { + return nil +} + +func (st *testStream) ReadBytes() ([]byte, error) { + return nil, nil +} + +func (st *testStream) Close() error { + if st.closed { + return errors.New("already closed") + } + st.closed = true + return nil +} + +func (st *testStream) ResetOnClose() error { + if st.closed { + return errors.New("already closed") + } + st.closed = true + return nil +} + +func (st *testStream) ProtoSpec() (sttypes.ProtoSpec, error) { + return sttypes.ProtoIDToProtoSpec(st.ProtoID()) +} + +type testHost struct { + sm *streamManager + streams map[sttypes.StreamID]*testStream + lock sync.Mutex + + errHook streamErrorHook +} + +type streamErrorHook func(id sttypes.StreamID, err error) + +func newTestHost() *testHost { + return &testHost{ + streams: make(map[sttypes.StreamID]*testStream), + } +} + +func (h *testHost) ID() libp2p_peer.ID { + return myPeerID +} + +// NewStream mock the upper function logic. When stream setup and running protocol, the +// upper code logic will call StreamManager to add new stream +func (h *testHost) NewStream(ctx context.Context, p libp2p_peer.ID, pids ...protocol.ID) (network.Stream, error) { + if len(pids) == 0 { + return nil, errors.New("nil protocol ids") + } + var err error + stid := sttypes.StreamID(p) + defer func() { + if err != nil && h.errHook != nil { + h.errHook(stid, err) + } + }() + + st := newTestStream(stid, sttypes.ProtoID(pids[0])) + h.lock.Lock() + h.streams[stid] = st + h.lock.Unlock() + + err = h.sm.NewStream(st) + return nil, err +} + +func makeStreamID(index int) sttypes.StreamID { + return sttypes.StreamID(strconv.Itoa(index)) +} + +func makePeerID(index int) libp2p_peer.ID { + return libp2p_peer.ID(strconv.Itoa(index)) +} + +func makeRemotePeers(size int) []libp2p_peer.ID { + ids := make([]libp2p_peer.ID, 0, size) + for i := 1; i != size+1; i++ { + ids = append(ids, makePeerID(i)) + } + return ids +} + +type testPeerFinder struct { + peerIDs []libp2p_peer.ID + curIndex int32 + fpHook delayFunc +} + +type delayFunc func(id libp2p_peer.ID) <-chan struct{} + +func emptyDelayFunc(id libp2p_peer.ID) <-chan struct{} { + c := make(chan struct{}) + go func() { + c <- struct{}{} + }() + return c +} + +func newTestPeerFinder(ids []libp2p_peer.ID, fpHook delayFunc) *testPeerFinder { + return &testPeerFinder{ + peerIDs: ids, + curIndex: 0, + fpHook: fpHook, + } +} + +func (pf *testPeerFinder) FindPeers(ctx context.Context, ns string, peerLimit int) (<-chan libp2p_peer.AddrInfo, error) { + if peerLimit > len(pf.peerIDs) { + peerLimit = len(pf.peerIDs) + } + resC := make(chan libp2p_peer.AddrInfo) + + go func() { + defer close(resC) + + for i := 0; i != peerLimit; i++ { + // hack to prevent race + curIndex := atomic.LoadInt32(&pf.curIndex) + pid := pf.peerIDs[curIndex] + select { + case <-ctx.Done(): + return + case <-pf.fpHook(pid): + } + resC <- libp2p_peer.AddrInfo{ID: pid} + atomic.AddInt32(&pf.curIndex, 1) + if int(atomic.LoadInt32(&pf.curIndex)) == len(pf.peerIDs) { + pf.curIndex = 0 + } + } + }() + + return resC, nil +} diff --git a/p2p/stream/common/streammanager/streammanager.go b/p2p/stream/common/streammanager/streammanager.go new file mode 100644 index 0000000000..7e35f401bd --- /dev/null +++ b/p2p/stream/common/streammanager/streammanager.go @@ -0,0 +1,415 @@ +package streammanager + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/ethereum/go-ethereum/event" + "github.com/harmony-one/harmony/internal/utils" + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + "github.com/libp2p/go-libp2p-core/network" + libp2p_peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +var ( + // ErrStreamAlreadyRemoved is the error that a stream has already been removed + ErrStreamAlreadyRemoved = errors.New("stream already removed") +) + +// streamManager is the implementation of StreamManager. It manages streams on +// one single protocol. It does the following job: +// 1. add a new stream to manage with when a new stream starts running. +// 2. closes a stream when some unexpected error happens. +// 3. discover new streams when the number of streams is below threshold. +// 4. emit stream events for other modules. +type streamManager struct { + // streamManager only manages streams on one protocol. + myProtoID sttypes.ProtoID + myProtoSpec sttypes.ProtoSpec + config Config + // streams is the map of peer ID to stream + // Note that it could happen that remote node does not share exactly the same + // protocol ID (e.g. different version) + streams *streamSet + // libp2p utilities + host host + pf peerFinder + handleStream func(stream network.Stream) + // incoming task channels + addStreamCh chan addStreamTask + rmStreamCh chan rmStreamTask + stopCh chan stopTask + discCh chan discTask + curTask interface{} + // utils + addStreamFeed event.Feed + removeStreamFeed event.Feed + logger zerolog.Logger + ctx context.Context + cancel func() +} + +// NewStreamManager creates a new stream manager for the given proto ID +func NewStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) StreamManager { + return newStreamManager(pid, host, pf, handleStream, c) +} + +// newStreamManager creates a new stream manager +func newStreamManager(pid sttypes.ProtoID, host host, pf peerFinder, handleStream func(network.Stream), c Config) *streamManager { + ctx, cancel := context.WithCancel(context.Background()) + + logger := utils.Logger().With().Str("module", "stream manager"). + Str("protocol ID", string(pid)).Logger() + + protoSpec, _ := sttypes.ProtoIDToProtoSpec(pid) + + return &streamManager{ + myProtoID: pid, + myProtoSpec: protoSpec, + config: c, + streams: newStreamSet(), + host: host, + pf: pf, + handleStream: handleStream, + addStreamCh: make(chan addStreamTask), + rmStreamCh: make(chan rmStreamTask), + stopCh: make(chan stopTask), + discCh: make(chan discTask, 1), // discCh is a buffered channel to avoid overuse of goroutine + + logger: logger, + ctx: ctx, + cancel: cancel, + } +} + +// Start starts the stream manager +func (sm *streamManager) Start() { + go sm.loop() +} + +// Close close the stream manager +func (sm *streamManager) Close() { + task := stopTask{done: make(chan struct{})} + sm.stopCh <- task + + <-task.done +} + +func (sm *streamManager) loop() { + var ( + discTicker = time.NewTicker(checkInterval) + discCtx context.Context + discCancel func() + ) + // bootstrap discovery + sm.discCh <- discTask{} + + for { + select { + case <-discTicker.C: + if !sm.softHaveEnoughStreams() { + sm.discCh <- discTask{} + } + + case <-sm.discCh: + // cancel last discovery + if discCancel != nil { + discCancel() + } + discCtx, discCancel = context.WithCancel(sm.ctx) + go func() { + err := sm.discoverAndSetupStream(discCtx) + if err != nil { + sm.logger.Err(err) + } + }() + + case addStream := <-sm.addStreamCh: + err := sm.handleAddStream(addStream.st) + addStream.errC <- err + + case rmStream := <-sm.rmStreamCh: + err := sm.handleRemoveStream(rmStream.id) + rmStream.errC <- err + + case stop := <-sm.stopCh: + sm.cancel() + sm.removeAllStreamOnClose() + stop.done <- struct{}{} + return + } + } +} + +// NewStream handles a new stream from stream handler protocol +func (sm *streamManager) NewStream(stream sttypes.Stream) error { + if err := sm.sanityCheckStream(stream); err != nil { + return errors.Wrap(err, "stream sanity check failed") + } + task := addStreamTask{ + st: stream, + errC: make(chan error), + } + sm.addStreamCh <- task + return <-task.errC +} + +// RemoveStream close and remove a stream from stream manager +func (sm *streamManager) RemoveStream(stID sttypes.StreamID) error { + task := rmStreamTask{ + id: stID, + errC: make(chan error), + } + sm.rmStreamCh <- task + return <-task.errC +} + +// GetStreams return the streams. +func (sm *streamManager) GetStreams() []sttypes.Stream { + return sm.streams.getStreams() +} + +// GetStreamByID return the stream with the given id. +func (sm *streamManager) GetStreamByID(id sttypes.StreamID) (sttypes.Stream, bool) { + return sm.streams.get(id) +} + +type ( + addStreamTask struct { + st sttypes.Stream + errC chan error + } + + rmStreamTask struct { + id sttypes.StreamID + errC chan error + } + + discTask struct{} + + stopTask struct { + done chan struct{} + } +) + +// sanity checks the service, network and shard ID +func (sm *streamManager) sanityCheckStream(st sttypes.Stream) error { + mySpec := sm.myProtoSpec + rmSpec, err := st.ProtoSpec() + if err != nil { + return err + } + if mySpec.Service != rmSpec.Service { + return fmt.Errorf("unexpected service: %v/%v", rmSpec.Service, mySpec.Service) + } + if mySpec.NetworkType != rmSpec.NetworkType { + return fmt.Errorf("unexpected network: %v/%v", rmSpec.NetworkType, mySpec.NetworkType) + } + if mySpec.ShardID != rmSpec.ShardID { + return fmt.Errorf("unexpected shard ID: %v/%v", rmSpec.ShardID, mySpec.ShardID) + } + return nil +} + +func (sm *streamManager) handleAddStream(st sttypes.Stream) error { + id := st.ID() + if sm.streams.size() >= sm.config.HiCap { + return errors.New("too many streams") + } + if _, ok := sm.streams.get(id); ok { + return errors.New("stream already exist") + } + + sm.streams.addStream(st) + + sm.addStreamFeed.Send(EvtStreamAdded{st}) + return nil +} + +func (sm *streamManager) handleRemoveStream(id sttypes.StreamID) error { + st, ok := sm.streams.get(id) + if !ok { + return ErrStreamAlreadyRemoved + } + + sm.streams.deleteStream(st) + // if stream number is smaller than HardLoCap, spin up the discover + if !sm.hardHaveEnoughStream() { + select { + case sm.discCh <- discTask{}: + default: + } + } + sm.removeStreamFeed.Send(EvtStreamRemoved{id}) + return nil +} + +func (sm *streamManager) removeAllStreamOnClose() { + var wg sync.WaitGroup + + for _, st := range sm.streams.slice() { + wg.Add(1) + go func(st sttypes.Stream) { + defer wg.Done() + err := st.ResetOnClose() + if err != nil { + sm.logger.Warn().Err(err).Str("stream ID", string(st.ID())). + Msg("failed to close stream") + } + }(st) + } + wg.Wait() + + // Be nice. after close, the field is still accessible to prevent potential panics + sm.streams = newStreamSet() +} + +func (sm *streamManager) discoverAndSetupStream(discCtx context.Context) error { + peers, err := sm.discover(discCtx) + if err != nil { + return errors.Wrap(err, "failed to discover") + } + for peer := range peers { + if peer.ID == sm.host.ID() { + continue + } + go func(pid libp2p_peer.ID) { + // The ctx here is using the module context instead of discover context + err := sm.setupStreamWithPeer(sm.ctx, pid) + if err != nil { + sm.logger.Warn().Err(err).Str("peerID", string(pid)).Msg("failed to setup stream with peer") + return + } + }(peer.ID) + } + return nil +} + +func (sm *streamManager) discover(ctx context.Context) (<-chan libp2p_peer.AddrInfo, error) { + protoID := string(sm.myProtoID) + discBatch := sm.config.DiscBatch + if sm.config.HiCap-sm.streams.size() < sm.config.DiscBatch { + discBatch = sm.config.HiCap - sm.streams.size() + } + if discBatch < 0 { + return nil, nil + } + + ctx, _ = context.WithTimeout(ctx, discTimeout) + return sm.pf.FindPeers(ctx, protoID, discBatch) +} + +func (sm *streamManager) setupStreamWithPeer(ctx context.Context, pid libp2p_peer.ID) error { + nCtx, cancel := context.WithTimeout(ctx, connectTimeout) + defer cancel() + + st, err := sm.host.NewStream(nCtx, pid, protocol.ID(sm.myProtoID)) + if err != nil { + return err + } + if sm.handleStream != nil { + go sm.handleStream(st) + } + return nil +} + +func (sm *streamManager) softHaveEnoughStreams() bool { + availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) + return availStreams >= sm.config.SoftLoCap +} + +func (sm *streamManager) hardHaveEnoughStream() bool { + availStreams := sm.streams.numStreamsWithMinProtoSpec(sm.myProtoSpec) + return availStreams >= sm.config.HardLoCap +} + +// streamSet is the concurrency safe stream set. +type streamSet struct { + streams map[sttypes.StreamID]sttypes.Stream + numByProto map[sttypes.ProtoSpec]int + lock sync.RWMutex +} + +func newStreamSet() *streamSet { + return &streamSet{ + streams: make(map[sttypes.StreamID]sttypes.Stream), + numByProto: make(map[sttypes.ProtoSpec]int), + } +} + +func (ss *streamSet) size() int { + ss.lock.RLock() + defer ss.lock.RUnlock() + + return len(ss.streams) +} + +func (ss *streamSet) get(id sttypes.StreamID) (sttypes.Stream, bool) { + ss.lock.RLock() + defer ss.lock.RUnlock() + + st, ok := ss.streams[id] + return st, ok +} + +func (ss *streamSet) addStream(st sttypes.Stream) { + ss.lock.Lock() + defer ss.lock.Unlock() + + ss.streams[st.ID()] = st + spec, _ := st.ProtoSpec() + ss.numByProto[spec]++ +} + +func (ss *streamSet) deleteStream(st sttypes.Stream) { + ss.lock.Lock() + defer ss.lock.Unlock() + + delete(ss.streams, st.ID()) + + spec, _ := st.ProtoSpec() + ss.numByProto[spec]-- + if ss.numByProto[spec] == 0 { + delete(ss.numByProto, spec) + } +} + +func (ss *streamSet) slice() []sttypes.Stream { + ss.lock.RLock() + defer ss.lock.RUnlock() + + sts := make([]sttypes.Stream, 0, len(ss.streams)) + for _, st := range ss.streams { + sts = append(sts, st) + } + return sts +} + +func (ss *streamSet) getStreams() []sttypes.Stream { + ss.lock.RLock() + defer ss.lock.RUnlock() + + res := make([]sttypes.Stream, 0, len(ss.streams)) + for _, st := range ss.streams { + res = append(res, st) + } + return res +} + +func (ss *streamSet) numStreamsWithMinProtoSpec(minSpec sttypes.ProtoSpec) int { + ss.lock.RLock() + defer ss.lock.RUnlock() + + var res int + for spec, num := range ss.numByProto { + if !spec.Version.LessThan(minSpec.Version) { + res += num + } + } + return res +} diff --git a/p2p/stream/common/streammanager/streammanager_test.go b/p2p/stream/common/streammanager/streammanager_test.go new file mode 100644 index 0000000000..2d7ded5180 --- /dev/null +++ b/p2p/stream/common/streammanager/streammanager_test.go @@ -0,0 +1,243 @@ +package streammanager + +import ( + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" + + sttypes "github.com/harmony-one/harmony/p2p/stream/types" + libp2p_peer "github.com/libp2p/go-libp2p-core/peer" +) + +const ( + defTestWait = 100 * time.Millisecond +) + +// When started, discover will be run at bootstrap +func TestStreamManager_BootstrapDisc(t *testing.T) { + sm := newTestStreamManager() + sm.host.(*testHost).errHook = func(id sttypes.StreamID, err error) { + t.Errorf("%s stream error: %v", id, err) + } + + // After bootstrap, stream manager shall discover streams and setup connection + // Note host will mock the upper code logic to call sm.NewStream in this case + sm.Start() + time.Sleep(defTestWait) + if gotSize := sm.streams.size(); gotSize != sm.config.DiscBatch { + t.Errorf("unexpected stream size: %v != %v", gotSize, sm.config.DiscBatch) + } +} + +// After close, all stream shall be closed and removed +func TestStreamManager_Close(t *testing.T) { + sm := newTestStreamManager() + // Bootstrap + sm.Start() + time.Sleep(defTestWait) + // Close stream manager, all stream shall be closed and removed + closeDone := make(chan struct{}) + go func() { + sm.Close() + closeDone <- struct{}{} + }() + select { + case <-time.After(defTestWait): + t.Errorf("still not closed") + case <-closeDone: + } + // Check stream been removed from stream manager and all streams to be closed + if sm.streams.size() != 0 { + t.Errorf("after close, stream not removed from stream manager") + } + host := sm.host.(*testHost) + for _, st := range host.streams { + if !st.closed { + t.Errorf("after close, stream still not closed") + } + } +} + +// Close shall terminate the current discover at once +func TestStreamManager_CloseDisc(t *testing.T) { + sm := newTestStreamManager() + // discover will be blocked forever + sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { + select {} + } + sm.Start() + time.Sleep(defTestWait) + // Close stream manager, all stream shall be closed and removed + closeDone := make(chan struct{}) + go func() { + sm.Close() + closeDone <- struct{}{} + }() + select { + case <-time.After(defTestWait): + t.Errorf("close shall unblock the current discovery") + case <-closeDone: + } +} + +// Each time discTicker ticks, it will cancel last discovery, and start a new one +func TestStreamManager_refreshDisc(t *testing.T) { + sm := newTestStreamManager() + // discover will be blocked for the first time but good for second time + var once sync.Once + sm.pf.(*testPeerFinder).fpHook = func(id libp2p_peer.ID) <-chan struct{} { + var sendSig = true + once.Do(func() { + sendSig = false + }) + c := make(chan struct{}, 1) + if sendSig { + c <- struct{}{} + } + return c + } + sm.Start() + time.Sleep(defTestWait) + + sm.discCh <- struct{}{} + time.Sleep(defTestWait) + + // We shall now have non-zero streams setup + if sm.streams.size() == 0 { + t.Errorf("stream size still zero after refresh") + } +} + +func TestStreamManager_HandleNewStream(t *testing.T) { + tests := []struct { + stream sttypes.Stream + expSize int + expErr error + }{ + { + stream: newTestStream(makeStreamID(100), testProtoID), + expSize: DefDiscBatch + 1, + expErr: nil, + }, + { + stream: newTestStream(makeStreamID(1), testProtoID), + expSize: DefDiscBatch, + expErr: errors.New("stream already exist"), + }, + } + for i, test := range tests { + sm := newTestStreamManager() + sm.Start() + time.Sleep(defTestWait) + + err := sm.NewStream(test.stream) + if assErr := assertError(err, test.expErr); assErr != nil { + t.Errorf("Test %v: %v", i, assErr) + } + + if sm.streams.size() != test.expSize { + t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), + test.expSize) + } + } +} + +func TestStreamManager_HandleRemoveStream(t *testing.T) { + tests := []struct { + id sttypes.StreamID + expSize int + expErr error + }{ + { + id: makeStreamID(1), + expSize: DefDiscBatch - 1, + expErr: nil, + }, + { + id: makeStreamID(100), + expSize: DefDiscBatch, + expErr: errors.New("stream already removed"), + }, + } + for i, test := range tests { + sm := newTestStreamManager() + sm.Start() + time.Sleep(defTestWait) + + err := sm.RemoveStream(test.id) + if assErr := assertError(err, test.expErr); assErr != nil { + t.Errorf("Test %v: %v", i, assErr) + } + + if sm.streams.size() != test.expSize { + t.Errorf("Test %v: unexpected stream size: %v / %v", i, sm.streams.size(), + test.expSize) + } + } +} + +// When number of streams is smaller than hard low limit, discover will be triggered +func TestStreamManager_HandleRemoveStream_Disc(t *testing.T) { + sm := newTestStreamManager() + sm.Start() + time.Sleep(defTestWait) + + // Remove DiscBatch - HardLoCap + 1 streams + num := 0 + for _, st := range sm.streams.slice() { + if err := sm.RemoveStream(st.ID()); err != nil { + t.Error(err) + } + num++ + if num == sm.config.DiscBatch-sm.config.HardLoCap+1 { + break + } + } + + // Last remove stream will also trigger discover + time.Sleep(defTestWait) + if sm.streams.size() != sm.config.HardLoCap+sm.config.DiscBatch-1 { + t.Errorf("unexpected stream number %v / %v", sm.streams.size(), sm.config.HardLoCap+sm.config.DiscBatch-1) + } +} + +func TestStreamSet_numStreamsWithMinProtoID(t *testing.T) { + var ( + pid1 = testProtoID + numPid1 = 5 + + pid2 = sttypes.ProtoID("harmony/sync/unitest/0/1.0.1") + numPid2 = 10 + ) + + ss := newStreamSet() + + for i := 0; i != numPid1; i++ { + ss.addStream(newTestStream(makeStreamID(i), pid1)) + } + for i := 0; i != numPid2; i++ { + ss.addStream(newTestStream(makeStreamID(i), pid2)) + } + + minSpec, _ := sttypes.ProtoIDToProtoSpec(pid2) + num := ss.numStreamsWithMinProtoSpec(minSpec) + if num != numPid2 { + t.Errorf("unexpected result: %v/%v", num, numPid2) + } +} + +func assertError(got, exp error) error { + if (got == nil) != (exp == nil) { + return fmt.Errorf("unexpected error: %v / %v", got, exp) + } + if got == nil { + return nil + } + if !strings.Contains(got.Error(), exp.Error()) { + return fmt.Errorf("unexpected error: %v / %v", got, exp) + } + return nil +}