diff --git a/network/interface.go b/network/interface.go index 14b170239..0b594a2e5 100644 --- a/network/interface.go +++ b/network/interface.go @@ -26,12 +26,18 @@ func (t TopicID) String() string { type EventType int const ( - EventTypeGossip EventType = 1 - EventTypeStream EventType = 2 + EventTypeConnect EventType = 1 + EventTypeDisconnect EventType = 2 + EventTypeGossip EventType = 3 + EventTypeStream EventType = 4 ) func (t EventType) String() string { switch t { + case EventTypeConnect: + return "connect" + case EventTypeDisconnect: + return "disconnect" case EventTypeGossip: return "gossip-msg" case EventTypeStream: @@ -68,6 +74,24 @@ func (*StreamMessage) Type() EventType { return EventTypeStream } +// ConnectEvent represents a peer connection event. +type ConnectEvent struct { + PeerID lp2pcore.PeerID +} + +func (*ConnectEvent) Type() EventType { + return EventTypeConnect +} + +// DisconnectEvent represents a peer disconnection event. +type DisconnectEvent struct { + PeerID lp2pcore.PeerID +} + +func (*DisconnectEvent) Type() EventType { + return EventTypeDisconnect +} + type Network interface { Start() error Stop() diff --git a/network/mdns_test.go b/network/mdns_test.go index 510aeceaa..fa67d4dfb 100644 --- a/network/mdns_test.go +++ b/network/mdns_test.go @@ -26,7 +26,9 @@ func TestMDNS(t *testing.T) { msg := []byte("test-mdns") assert.NoError(t, net1.SendTo(msg, net2.SelfID())) - e := shouldReceiveEvent(t, net2).(*StreamMessage) - assert.Equal(t, e.Source, net1.SelfID()) - assert.Equal(t, readData(t, e.Reader, len(msg)), msg) + e1 := shouldReceiveEvent(t, net2, EventTypeConnect).(*ConnectEvent) + assert.Equal(t, e1.PeerID, net1.SelfID()) + e2 := shouldReceiveEvent(t, net2, EventTypeStream).(*StreamMessage) + assert.Equal(t, e2.Source, net1.SelfID()) + assert.Equal(t, readData(t, e2.Reader, len(msg)), msg) } diff --git a/network/network.go b/network/network.go index e3cb114e3..9e9cca045 100644 --- a/network/network.go +++ b/network/network.go @@ -36,6 +36,7 @@ type network struct { dht *dhtService stream *streamService gossip *gossipService + notifee *NotifeeService generalTopic *lp2pps.Topic consensusTopic *lp2pps.Topic eventChannel chan Event @@ -171,6 +172,9 @@ func newNetwork(conf *Config, opts []lp2p.Option) (*network, error) { n.dht = newDHTService(n.ctx, n.host, kadProtocolID, conf.Bootstrap, n.logger) n.stream = newStreamService(ctx, n.host, streamProtocolID, relayAddrs, n.eventChannel, n.logger) n.gossip = newGossipService(ctx, n.host, n.eventChannel, n.logger) + n.notifee = newNotifeeService(n.host, n.eventChannel, n.logger) + + n.host.Network().Notify(n.notifee) n.logger.Info("network setup", "id", n.host.ID(), "address", conf.Listens) diff --git a/network/network_test.go b/network/network_test.go index 5780a1dee..e648db69d 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -66,7 +66,7 @@ func testConfig() *Config { } } -func shouldReceiveEvent(t *testing.T, net *network) Event { +func shouldReceiveEvent(t *testing.T, net *network, eventType EventType) Event { timeout := time.NewTimer(2 * time.Second) for { @@ -75,7 +75,9 @@ func shouldReceiveEvent(t *testing.T, net *network) Event { require.NoError(t, fmt.Errorf("shouldReceiveEvent Timeout, test: %v id:%s", t.Name(), net.SelfID().String())) return nil case e := <-net.EventChannel(): - return e + if e.Type() == eventType { + return e + } } } } @@ -225,10 +227,10 @@ func TestNetwork(t *testing.T) { require.NoError(t, networkP.Broadcast(msg, TopicIDGeneral)) - eB := shouldReceiveEvent(t, networkB).(*GossipMessage) - eM := shouldReceiveEvent(t, networkM).(*GossipMessage) - eN := shouldReceiveEvent(t, networkN).(*GossipMessage) - eX := shouldReceiveEvent(t, networkX).(*GossipMessage) + eB := shouldReceiveEvent(t, networkB, EventTypeGossip).(*GossipMessage) + eM := shouldReceiveEvent(t, networkM, EventTypeGossip).(*GossipMessage) + eN := shouldReceiveEvent(t, networkN, EventTypeGossip).(*GossipMessage) + eX := shouldReceiveEvent(t, networkX, EventTypeGossip).(*GossipMessage) assert.Equal(t, eB.Source, networkP.SelfID()) assert.Equal(t, eM.Source, networkP.SelfID()) @@ -246,9 +248,9 @@ func TestNetwork(t *testing.T) { require.NoError(t, networkP.Broadcast(msg, TopicIDConsensus)) - eB := shouldReceiveEvent(t, networkB).(*GossipMessage) - eM := shouldReceiveEvent(t, networkM).(*GossipMessage) - eN := shouldReceiveEvent(t, networkN).(*GossipMessage) + eB := shouldReceiveEvent(t, networkB, EventTypeGossip).(*GossipMessage) + eM := shouldReceiveEvent(t, networkM, EventTypeGossip).(*GossipMessage) + eN := shouldReceiveEvent(t, networkN, EventTypeGossip).(*GossipMessage) shouldNotReceiveEvent(t, networkX) assert.Equal(t, eB.Source, networkP.SelfID()) @@ -270,7 +272,7 @@ func TestNetwork(t *testing.T) { msgB := []byte("test-stream-from-b") require.NoError(t, networkB.SendTo(msgB, networkP.SelfID())) - eB := shouldReceiveEvent(t, networkP).(*StreamMessage) + eB := shouldReceiveEvent(t, networkP, EventTypeStream).(*StreamMessage) assert.Equal(t, eB.Source, networkB.SelfID()) assert.Equal(t, readData(t, eB.Reader, len(msgB)), msgB) }) @@ -279,7 +281,7 @@ func TestNetwork(t *testing.T) { msgM := []byte("test-stream-from-m") require.NoError(t, networkM.SendTo(msgM, networkN.SelfID())) - eM := shouldReceiveEvent(t, networkN).(*StreamMessage) + eM := shouldReceiveEvent(t, networkN, EventTypeStream).(*StreamMessage) assert.Equal(t, eM.Source, networkM.SelfID()) assert.Equal(t, readData(t, eM.Reader, len(msgM)), msgM) }) @@ -295,7 +297,8 @@ func TestNetwork(t *testing.T) { networkP.Stop() networkB.CloseConnection(networkP.SelfID()) - + e := shouldReceiveEvent(t, networkB, EventTypeDisconnect).(*DisconnectEvent) + assert.Equal(t, e.PeerID, networkP.SelfID()) require.Error(t, networkB.SendTo(msgB, networkP.SelfID())) }) } diff --git a/network/notifee.go b/network/notifee.go new file mode 100644 index 000000000..8be15acbd --- /dev/null +++ b/network/notifee.go @@ -0,0 +1,41 @@ +package network + +import ( + lp2phost "github.com/libp2p/go-libp2p/core/host" + lp2pnetwork "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" + "github.com/pactus-project/pactus/util/logger" +) + +type NotifeeService struct { + eventChannel chan<- Event + logger *logger.SubLogger +} + +func newNotifeeService(host lp2phost.Host, eventChannel chan<- Event, logger *logger.SubLogger) *NotifeeService { + notifee := &NotifeeService{ + eventChannel: eventChannel, + logger: logger, + } + host.Network().Notify(notifee) + return notifee +} + +func (n *NotifeeService) Connected(_ lp2pnetwork.Network, conn lp2pnetwork.Conn) { + n.eventChannel <- &ConnectEvent{PeerID: conn.RemotePeer()} +} + +func (n *NotifeeService) Disconnected(_ lp2pnetwork.Network, conn lp2pnetwork.Conn) { + n.eventChannel <- &DisconnectEvent{PeerID: conn.RemotePeer()} +} + +func (n *NotifeeService) Listen(_ lp2pnetwork.Network, ma ma.Multiaddr) { + // Handle listen event if needed. + n.logger.Debug("Notifee Listen event emitted", "Multiaddr", ma.String()) +} + +// ListenClose is called when your node stops listening on an address. +func (n *NotifeeService) ListenClose(_ lp2pnetwork.Network, ma ma.Multiaddr) { + // Handle listen close event if needed. + n.logger.Debug("Notifee ListenClose event emitted", "Multiaddr", ma.String()) +}