diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index 34371ffdf4..1ccc47116a 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -45,12 +45,12 @@ type BlockAnnounceMessage struct { } // SubProtocol returns the block-announces sub-protocol -func (bm *BlockAnnounceMessage) SubProtocol() string { +func (*BlockAnnounceMessage) SubProtocol() string { return blockAnnounceID } // Type returns BlockAnnounceMsgType -func (bm *BlockAnnounceMessage) Type() byte { +func (*BlockAnnounceMessage) Type() byte { return BlockAnnounceMsgType } @@ -105,7 +105,7 @@ func (bm *BlockAnnounceMessage) Hash() common.Hash { } // IsHandshake returns false -func (bm *BlockAnnounceMessage) IsHandshake() bool { +func (*BlockAnnounceMessage) IsHandshake() bool { return false } @@ -137,7 +137,7 @@ type BlockAnnounceHandshake struct { } // SubProtocol returns the block-announces sub-protocol -func (hs *BlockAnnounceHandshake) SubProtocol() string { +func (*BlockAnnounceHandshake) SubProtocol() string { return blockAnnounceID } @@ -170,17 +170,17 @@ func (hs *BlockAnnounceHandshake) Decode(in []byte) error { } // Type ... -func (hs *BlockAnnounceHandshake) Type() byte { +func (*BlockAnnounceHandshake) Type() byte { return 0 } // Hash ... -func (hs *BlockAnnounceHandshake) Hash() common.Hash { +func (*BlockAnnounceHandshake) Hash() common.Hash { return common.Hash{} } // IsHandshake returns true -func (hs *BlockAnnounceHandshake) IsHandshake() bool { +func (*BlockAnnounceHandshake) IsHandshake() bool { return true } @@ -220,7 +220,7 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err // don't need to lock here, since function is always called inside the func returned by // `createNotificationsMessageHandler` which locks the map beforehand. - data, ok := np.getHandshakeData(peer, true) + data, ok := np.getInboundHandshakeData(peer) if ok { data.handshake = hs // TODO: since this is used only for rpc system_peers only, diff --git a/dot/network/gossip_test.go b/dot/network/gossip_test.go index 7cd645ae72..160236acf5 100644 --- a/dot/network/gossip_test.go +++ b/dot/network/gossip_test.go @@ -42,7 +42,7 @@ func TestGossip(t *testing.T) { nodeA := createTestService(t, configA) handlerA := newTestStreamHandler(testBlockAnnounceMessageDecoder) - nodeA.host.registerStreamHandler("", handlerA.handleStream) + nodeA.host.registerStreamHandler(nodeA.host.protocolID, handlerA.handleStream) basePathB := utils.NewTestBasePath(t, "nodeB") configB := &Config{ @@ -54,7 +54,7 @@ func TestGossip(t *testing.T) { nodeB := createTestService(t, configB) handlerB := newTestStreamHandler(testBlockAnnounceMessageDecoder) - nodeB.host.registerStreamHandler("", handlerB.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID, handlerB.handleStream) addrInfoA := nodeA.host.addrInfo() err := nodeB.host.connect(addrInfoA) @@ -75,7 +75,7 @@ func TestGossip(t *testing.T) { nodeC := createTestService(t, configC) handlerC := newTestStreamHandler(testBlockAnnounceMessageDecoder) - nodeC.host.registerStreamHandler("", handlerC.handleStream) + nodeC.host.registerStreamHandler(nodeC.host.protocolID, handlerC.handleStream) err = nodeC.host.connect(addrInfoA) // retry connect if "failed to dial" error diff --git a/dot/network/host.go b/dot/network/host.go index 6ef34855f7..ab24f78b8e 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -224,19 +224,9 @@ func (h *host) close() error { return nil } -// registerStreamHandler registers the stream handler, appending the given sub-protocol to the main protocol ID -func (h *host) registerStreamHandler(sub protocol.ID, handler func(libp2pnetwork.Stream)) { - h.h.SetStreamHandler(h.protocolID+sub, handler) -} - -// registerStreamHandlerWithOverwrite registers the stream handler. if overwrite is true, it uses the passed protocol ID -// for the handler, otherwise it appends the given sub-protocol to the main protocol ID -func (h *host) registerStreamHandlerWithOverwrite(pid protocol.ID, overwrite bool, handler func(libp2pnetwork.Stream)) { - if overwrite { - h.h.SetStreamHandler(pid, handler) - } else { - h.h.SetStreamHandler(h.protocolID+pid, handler) - } +// registerStreamHandler registers the stream handler for the given protocol id. +func (h *host) registerStreamHandler(pid protocol.ID, handler func(libp2pnetwork.Stream)) { + h.h.SetStreamHandler(pid, handler) } // connect connects the host to a specific peer address @@ -251,8 +241,10 @@ func (h *host) connect(p peer.AddrInfo) (err error) { // bootstrap connects the host to the configured bootnodes func (h *host) bootstrap() { failed := 0 - all := append(h.bootnodes, h.persistentPeers...) - for _, addrInfo := range all { + var allNodes []peer.AddrInfo + allNodes = append(allNodes, h.bootnodes...) + allNodes = append(allNodes, h.persistentPeers...) + for _, addrInfo := range allNodes { logger.Debug("bootstrapping to peer", "peer", addrInfo.ID) err := h.connect(addrInfo) if err != nil { @@ -260,7 +252,7 @@ func (h *host) bootstrap() { failed++ } } - if failed == len(all) && len(all) != 0 { + if failed == len(allNodes) && len(allNodes) != 0 { logger.Error("failed to bootstrap to any bootnode") } } diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 918075937c..c3e6f0cd36 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -188,7 +188,7 @@ func TestSend(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true handler := newTestStreamHandler(testBlockRequestMessageDecoder) - nodeB.host.registerStreamHandler("", handler.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID, handler.handleStream) addrInfoB := nodeB.host.addrInfo() err := nodeA.host.connect(addrInfoB) @@ -223,7 +223,7 @@ func TestExistingStream(t *testing.T) { nodeA := createTestService(t, configA) nodeA.noGossip = true handlerA := newTestStreamHandler(testBlockRequestMessageDecoder) - nodeA.host.registerStreamHandler("", handlerA.handleStream) + nodeA.host.registerStreamHandler(nodeA.host.protocolID, handlerA.handleStream) addrInfoA := nodeA.host.addrInfo() basePathB := utils.NewTestBasePath(t, "nodeB") @@ -237,7 +237,7 @@ func TestExistingStream(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true handlerB := newTestStreamHandler(testBlockRequestMessageDecoder) - nodeB.host.registerStreamHandler("", handlerB.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID, handlerB.handleStream) addrInfoB := nodeB.host.addrInfo() err := nodeA.host.connect(addrInfoB) @@ -329,7 +329,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { }) // Verify that handshake data exists. - _, ok := info.getHandshakeData(nodeB.host.id(), true) + _, ok := info.getInboundHandshakeData(nodeB.host.id()) require.True(t, ok) time.Sleep(time.Second) @@ -339,7 +339,7 @@ func TestStreamCloseMetadataCleanup(t *testing.T) { time.Sleep(time.Second) // Verify that handshake data is cleared. - _, ok = info.getHandshakeData(nodeB.host.id(), true) + _, ok = info.getInboundHandshakeData(nodeB.host.id()) require.False(t, ok) } @@ -501,7 +501,7 @@ func TestStreamCloseEOF(t *testing.T) { nodeB := createTestService(t, configB) nodeB.noGossip = true handler := newTestStreamHandler(testBlockRequestMessageDecoder) - nodeB.host.registerStreamHandler("", handler.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID, handler.handleStream) require.False(t, handler.exit) addrInfoB := nodeB.host.addrInfo() diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 91f0a21245..2c62fcff63 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -18,18 +18,20 @@ package network import ( "errors" + "reflect" "sync" "time" - "unsafe" libp2pnetwork "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" ) -var errCannotValidateHandshake = errors.New("failed to validate handshake") +var ( + errCannotValidateHandshake = errors.New("failed to validate handshake") + maxHandshakeSize = reflect.TypeOf(BlockAnnounceHandshake{}).Size() +) -const maxHandshakeSize = unsafe.Sizeof(BlockAnnounceHandshake{}) //nolint const handshakeTimeout = time.Second * 10 // Handshake is the interface all handshakes for notifications protocols must implement @@ -70,18 +72,27 @@ type notificationsProtocol struct { outboundHandshakeData *sync.Map //map[peer.ID]*handshakeData } -func (n *notificationsProtocol) getHandshakeData(pid peer.ID, inbound bool) (handshakeData, bool) { +func (n *notificationsProtocol) getInboundHandshakeData(pid peer.ID) (handshakeData, bool) { var ( data interface{} has bool ) - if inbound { - data, has = n.inboundHandshakeData.Load(pid) - } else { - data, has = n.outboundHandshakeData.Load(pid) + data, has = n.inboundHandshakeData.Load(pid) + if !has { + return handshakeData{}, false } + return data.(handshakeData), true +} + +func (n *notificationsProtocol) getOutboundHandshakeData(pid peer.ID) (handshakeData, bool) { + var ( + data interface{} + has bool + ) + + data, has = n.outboundHandshakeData.Load(pid) if !has { return handshakeData{}, false } @@ -110,7 +121,18 @@ func createDecoder(info *notificationsProtocol, handshakeDecoder HandshakeDecode return func(in []byte, peer peer.ID, inbound bool) (Message, error) { // if we don't have handshake data on this peer, or we haven't received the handshake from them already, // assume we are receiving the handshake - if hsData, has := info.getHandshakeData(peer, inbound); !has || !hsData.received { + var ( + hsData handshakeData + has bool + ) + + if inbound { + hsData, has = info.getInboundHandshakeData(peer) + } else { + hsData, has = info.getOutboundHandshakeData(peer) + } + + if !has || !hsData.received { return handshakeDecoder(in) } @@ -150,7 +172,7 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, // note: if this function is being called, it's being called via SetStreamHandler, // ie it is an inbound stream and we only send the handshake over it. // we do not send any other data over this stream, we would need to open a new outbound stream. - if _, has := info.getHandshakeData(peer, true); !has { + if _, has := info.getInboundHandshakeData(peer); !has { logger.Trace("receiver: validating handshake", "protocol", info.protocolID) hsData := newHandshakeData(true, false, stream) @@ -211,7 +233,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc return } - hsData, has := info.getHandshakeData(peer, false) + hsData, has := info.getOutboundHandshakeData(peer) if has && !hsData.validated { // peer has sent us an invalid handshake in the past, ignore return diff --git a/dot/network/notifications_test.go b/dot/network/notifications_test.go index 059a773b1d..b71f1a6faa 100644 --- a/dot/network/notifications_test.go +++ b/dot/network/notifications_test.go @@ -19,9 +19,11 @@ package network import ( "fmt" "math/big" + "reflect" "sync" "testing" "time" + "unsafe" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" @@ -89,7 +91,7 @@ func TestCreateDecoder_BlockAnnounce(t *testing.T) { require.NoError(t, err) // set handshake data to received - hsData, _ := info.getHandshakeData(testPeerID, true) + hsData, _ := info.getInboundHandshakeData(testPeerID) hsData.received = true info.inboundHandshakeData.Store(testPeerID, hsData) msg, err = decoder(enc, testPeerID, true) @@ -210,7 +212,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.Equal(t, errCannotValidateHandshake, err) - data, has := info.getHandshakeData(testPeerID, true) + data, has := info.getInboundHandshakeData(testPeerID) require.True(t, has) require.True(t, data.received) require.False(t, data.validated) @@ -227,7 +229,7 @@ func TestCreateNotificationsMessageHandler_BlockAnnounceHandshake(t *testing.T) err = handler(stream, testHandshake) require.NoError(t, err) - data, has = info.getHandshakeData(testPeerID, true) + data, has = info.getInboundHandshakeData(testPeerID) require.True(t, has) require.True(t, data.received) require.True(t, data.validated) @@ -293,7 +295,7 @@ func Test_HandshakeTimeout(t *testing.T) { time.Sleep(time.Second) // Verify that handshake data exists. - _, ok := info.getHandshakeData(nodeB.host.id(), false) + _, ok := info.getOutboundHandshakeData(nodeB.host.id()) require.True(t, ok) // a stream should be open until timeout @@ -305,7 +307,7 @@ func Test_HandshakeTimeout(t *testing.T) { time.Sleep(handshakeTimeout) // handshake data should be removed - _, ok = info.getHandshakeData(nodeB.host.id(), false) + _, ok = info.getOutboundHandshakeData(nodeB.host.id()) require.False(t, ok) // stream should be closed @@ -313,3 +315,7 @@ func Test_HandshakeTimeout(t *testing.T) { require.Len(t, connAToB, 1) require.Len(t, connAToB[0].GetStreams(), 0) } + +func TestBlockAnnounceHandshakeSize(t *testing.T) { + require.Equal(t, unsafe.Sizeof(BlockAnnounceHandshake{}), reflect.TypeOf(BlockAnnounceHandshake{}).Size()) +} diff --git a/dot/network/service.go b/dot/network/service.go index 7a64e5a243..3b9ce545b5 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -206,19 +206,18 @@ func (s *Service) Start() error { s.syncQueue.peerScore.Delete(p) }) - s.host.registerStreamHandler(syncID, s.handleSyncStream) - s.host.registerStreamHandler(lightID, s.handleLightStream) + s.host.registerStreamHandler(s.host.protocolID+syncID, s.handleSyncStream) + s.host.registerStreamHandler(s.host.protocolID+lightID, s.handleLightStream) // register block announce protocol err := s.RegisterNotificationsProtocol( - blockAnnounceID, + s.host.protocolID+blockAnnounceID, BlockAnnounceMsgType, s.getBlockAnnounceHandshake, decodeBlockAnnounceHandshake, s.validateBlockAnnounceHandshake, decodeBlockAnnounceMessage, s.handleBlockAnnounceMessage, - false, ) if err != nil { logger.Warn("failed to register notifications protocol", "sub-protocol", blockAnnounceID, "error", err) @@ -226,14 +225,13 @@ func (s *Service) Start() error { // register transactions protocol err = s.RegisterNotificationsProtocol( - transactionsID, + s.host.protocolID+transactionsID, TransactionMsgType, s.getTransactionHandshake, decodeTransactionHandshake, validateTransactionHandshake, decodeTransactionMessage, s.handleTransactionMessage, - false, ) if err != nil { logger.Warn("failed to register notifications protocol", "sub-protocol", blockAnnounceID, "error", err) @@ -414,14 +412,14 @@ mainloop: // RegisterNotificationsProtocol registers a protocol with the network service with the given handler // messageID is a user-defined message ID for the message passed over this protocol. -func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, +func (s *Service) RegisterNotificationsProtocol( + protocolID protocol.ID, messageID byte, handshakeGetter HandshakeGetter, handshakeDecoder HandshakeDecoder, handshakeValidator HandshakeValidator, messageDecoder MessageDecoder, messageHandler NotificationsMessageHandler, - overwriteProtocol bool, ) error { s.notificationsMu.Lock() defer s.notificationsMu.Unlock() @@ -430,13 +428,6 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, return errors.New("notifications protocol with message type already exists") } - var protocolID protocol.ID - if overwriteProtocol { - protocolID = sub - } else { - protocolID = s.host.protocolID + sub - } - np := ¬ificationsProtocol{ protocolID: protocolID, getHandshake: handshakeGetter, @@ -449,7 +440,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, connMgr := s.host.h.ConnManager().(*ConnManager) connMgr.registerCloseHandler(protocolID, func(peerID peer.ID) { - if _, ok := np.getHandshakeData(peerID, true); ok { + if _, ok := np.getInboundHandshakeData(peerID); ok { logger.Trace( "Cleaning up inbound handshake data", "peer", peerID, @@ -458,7 +449,7 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, np.inboundHandshakeData.Delete(peerID) } - if _, ok := np.getHandshakeData(peerID, false); ok { + if _, ok := np.getOutboundHandshakeData(peerID); ok { logger.Trace( "Cleaning up outbound handshake data", "peer", peerID, @@ -473,8 +464,8 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, decoder := createDecoder(info, handshakeDecoder, messageDecoder) handlerWithValidate := s.createNotificationsMessageHandler(info, messageHandler) - s.host.registerStreamHandlerWithOverwrite(sub, overwriteProtocol, func(stream libp2pnetwork.Stream) { - logger.Trace("received stream", "sub-protocol", sub) + s.host.registerStreamHandler(protocolID, func(stream libp2pnetwork.Stream) { + logger.Trace("received stream", "sub-protocol", protocolID) conn := stream.Conn() if conn == nil { logger.Error("Failed to get connection from stream") @@ -684,7 +675,7 @@ func (s *Service) Peers() []common.PeerInfo { s.notificationsMu.RUnlock() for _, p := range s.host.peers() { - data, has := np.getHandshakeData(p, true) + data, has := np.getInboundHandshakeData(p) if !has || data.handshake == nil { peers = append(peers, common.PeerInfo{ PeerID: p.String(), diff --git a/dot/network/service_test.go b/dot/network/service_test.go index ffa5d85564..1aec6d43a6 100644 --- a/dot/network/service_test.go +++ b/dot/network/service_test.go @@ -56,7 +56,7 @@ func createServiceHelper(t *testing.T, num int) []*Service { srvc := createTestService(t, config) srvc.noGossip = true handler := newTestStreamHandler(testBlockAnnounceMessageDecoder) - srvc.host.registerStreamHandler("", handler.handleStream) + srvc.host.registerStreamHandler(srvc.host.protocolID, handler.handleStream) srvcs = append(srvcs, srvc) } @@ -163,7 +163,7 @@ func TestBroadcastMessages(t *testing.T) { defer nodeB.Stop() nodeB.noGossip = true handler := newTestStreamHandler(testBlockAnnounceHandshakeDecoder) - nodeB.host.registerStreamHandler(blockAnnounceID, handler.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID+blockAnnounceID, handler.handleStream) addrInfoB := nodeB.host.addrInfo() err := nodeA.host.connect(addrInfoB) @@ -208,7 +208,7 @@ func TestBroadcastDuplicateMessage(t *testing.T) { nodeB.noGossip = true handler := newTestStreamHandler(testBlockAnnounceHandshakeDecoder) - nodeB.host.registerStreamHandler(blockAnnounceID, handler.handleStream) + nodeB.host.registerStreamHandler(nodeB.host.protocolID+blockAnnounceID, handler.handleStream) addrInfoB := nodeB.host.addrInfo() err := nodeA.host.connect(addrInfoB) diff --git a/lib/grandpa/network.go b/lib/grandpa/network.go index ee9e19781e..fedd4597a7 100644 --- a/lib/grandpa/network.go +++ b/lib/grandpa/network.go @@ -53,7 +53,7 @@ type GrandpaHandshake struct { //nolint } // SubProtocol returns the grandpa sub-protocol -func (hs *GrandpaHandshake) SubProtocol() string { +func (*GrandpaHandshake) SubProtocol() string { return string(grandpaID) } @@ -79,29 +79,29 @@ func (hs *GrandpaHandshake) Decode(in []byte) error { } // Type ... -func (hs *GrandpaHandshake) Type() byte { +func (*GrandpaHandshake) Type() byte { return 0 } // Hash ... -func (hs *GrandpaHandshake) Hash() common.Hash { +func (*GrandpaHandshake) Hash() common.Hash { return common.Hash{} } // IsHandshake returns true -func (hs *GrandpaHandshake) IsHandshake() bool { +func (*GrandpaHandshake) IsHandshake() bool { return true } func (s *Service) registerProtocol() error { - return s.network.RegisterNotificationsProtocol(grandpaID, + return s.network.RegisterNotificationsProtocol( + grandpaID, messageID, s.getHandshake, s.decodeHandshake, s.validateHandshake, s.decodeMessage, s.handleNetworkMessage, - true, ) } @@ -119,17 +119,17 @@ func (s *Service) getHandshake() (Handshake, error) { }, nil } -func (s *Service) decodeHandshake(in []byte) (Handshake, error) { +func (*Service) decodeHandshake(in []byte) (Handshake, error) { hs := new(GrandpaHandshake) err := hs.Decode(in) return hs, err } -func (s *Service) validateHandshake(_ peer.ID, _ Handshake) error { +func (*Service) validateHandshake(_ peer.ID, _ Handshake) error { return nil } -func (s *Service) decodeMessage(in []byte) (NotificationsMessage, error) { +func (*Service) decodeMessage(in []byte) (NotificationsMessage, error) { msg := new(network.ConsensusMessage) err := msg.Decode(in) return msg, err diff --git a/lib/grandpa/round_test.go b/lib/grandpa/round_test.go index 8a55810f27..cf0bc89b4a 100644 --- a/lib/grandpa/round_test.go +++ b/lib/grandpa/round_test.go @@ -82,14 +82,14 @@ func (n *testNetwork) SendJustificationRequest(to peer.ID, num uint32) { } } -func (n *testNetwork) RegisterNotificationsProtocol(sub protocol.ID, +func (n *testNetwork) RegisterNotificationsProtocol( + pid protocol.ID, messageID byte, handshakeGetter network.HandshakeGetter, handshakeDecoder network.HandshakeDecoder, handshakeValidator network.HandshakeValidator, messageDecoder network.MessageDecoder, messageHandler network.NotificationsMessageHandler, - overwriteProtocol bool, ) error { return nil } diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index 14bcbbab44..d35c99021c 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -84,6 +84,5 @@ type Network interface { handshakeValidator network.HandshakeValidator, messageDecoder network.MessageDecoder, messageHandler network.NotificationsMessageHandler, - overwriteProtocol bool, ) error }