From ae35eeb61f6d7b86a13c4e1770ba1e4130a9fd78 Mon Sep 17 00:00:00 2001
From: Ceyhun Onur
Date: Mon, 1 Jul 2024 23:19:25 +0200
Subject: [PATCH] check router is closing in requests (#3157)

Co-authored-by: Stephen Buttolph
---
 snow/networking/router/chain_router.go      |  67 ++++++++++++
 snow/networking/router/chain_router_test.go | 111 ++++++++++++++++++++
 2 files changed, 178 insertions(+)

diff --git a/snow/networking/router/chain_router.go b/snow/networking/router/chain_router.go
index 27bf891ab4f9..6af0984afc3f 100644
--- a/snow/networking/router/chain_router.go
+++ b/snow/networking/router/chain_router.go
@@ -31,6 +31,7 @@ import (
 var (
     errUnknownChain  = errors.New("received message for unknown chain")
     errUnallowedNode = errors.New("received message from non-allowed node")
+    errClosing       = errors.New("router is closing")
 
     _ Router              = (*ChainRouter)(nil)
     _ benchlist.Benchable = (*ChainRouter)(nil)
@@ -63,6 +64,7 @@ type ChainRouter struct {
     clock         mockable.Clock
     log           logging.Logger
     lock          sync.Mutex
+    closing       bool
     chainHandlers map[ids.ID]handler.Handler
 
     // It is only safe to call [RegisterResponse] with the router lock held. Any
@@ -154,6 +156,18 @@ func (cr *ChainRouter) RegisterRequest(
     engineType p2p.EngineType,
 ) {
     cr.lock.Lock()
+    if cr.closing {
+        cr.log.Debug("dropping request",
+            zap.Stringer("nodeID", nodeID),
+            zap.Stringer("requestingChainID", requestingChainID),
+            zap.Stringer("respondingChainID", respondingChainID),
+            zap.Uint32("requestID", requestID),
+            zap.Stringer("messageOp", op),
+            zap.Error(errClosing),
+        )
+        cr.lock.Unlock()
+        return
+    }
     // When we receive a response message type (Chits, Put, Accepted, etc.)
     // we validate that we actually sent the corresponding request.
     // Give this request a unique ID so we can do that validation.
@@ -244,6 +258,17 @@ func (cr *ChainRouter) HandleInbound(ctx context.Context, msg message.InboundMes
     cr.lock.Lock()
     defer cr.lock.Unlock()
 
+    if cr.closing {
+        cr.log.Debug("dropping message",
+            zap.Stringer("messageOp", op),
+            zap.Stringer("nodeID", nodeID),
+            zap.Stringer("chainID", destinationChainID),
+            zap.Error(errClosing),
+        )
+        msg.OnFinishedHandling()
+        return
+    }
+
     // Get the chain, if it exists
     chain, exists := cr.chainHandlers[destinationChainID]
     if !exists {
@@ -356,6 +381,7 @@ func (cr *ChainRouter) Shutdown(ctx context.Context) {
     cr.lock.Lock()
     prevChains := cr.chainHandlers
     cr.chainHandlers = map[ids.ID]handler.Handler{}
+    cr.closing = true
     cr.lock.Unlock()
 
     for _, chain := range prevChains {
@@ -388,6 +414,13 @@ func (cr *ChainRouter) AddChain(ctx context.Context, chain handler.Handler) {
     defer cr.lock.Unlock()
 
     chainID := chain.Context().ChainID
+    if cr.closing {
+        cr.log.Debug("dropping add chain request",
+            zap.Stringer("chainID", chainID),
+            zap.Error(errClosing),
+        )
+        return
+    }
     cr.log.Debug("registering chain with chain router",
         zap.Stringer("chainID", chainID),
     )
@@ -446,6 +479,14 @@ func (cr *ChainRouter) Connected(nodeID ids.NodeID, nodeVersion *version.Applica
     cr.lock.Lock()
     defer cr.lock.Unlock()
 
+    if cr.closing {
+        cr.log.Debug("dropping connected message",
+            zap.Stringer("nodeID", nodeID),
+            zap.Error(errClosing),
+        )
+        return
+    }
+
     connectedPeer, exists := cr.peers[nodeID]
     if !exists {
         connectedPeer = &peer{
@@ -493,6 +534,14 @@ func (cr *ChainRouter) Disconnected(nodeID ids.NodeID) {
     cr.lock.Lock()
     defer cr.lock.Unlock()
 
+    if cr.closing {
+        cr.log.Debug("dropping disconnected message",
+            zap.Stringer("nodeID", nodeID),
+            zap.Error(errClosing),
+        )
+        return
+    }
+
     peer := cr.peers[nodeID]
     delete(cr.peers, nodeID)
     if _, benched := cr.benched[nodeID]; benched {
@@ -522,6 +571,15 @@ func (cr *ChainRouter) Benched(chainID ids.ID, nodeID ids.NodeID) {
     cr.lock.Lock()
     defer cr.lock.Unlock()
 
+    if cr.closing {
+        cr.log.Debug("dropping benched message",
+            zap.Stringer("nodeID", nodeID),
+            zap.Stringer("chainID", chainID),
+            zap.Error(errClosing),
+        )
+        return
+    }
+
     benchedChains, exists := cr.benched[nodeID]
     benchedChains.Add(chainID)
     cr.benched[nodeID] = benchedChains
@@ -554,6 +612,15 @@ func (cr *ChainRouter) Unbenched(chainID ids.ID, nodeID ids.NodeID) {
     cr.lock.Lock()
     defer cr.lock.Unlock()
 
+    if cr.closing {
+        cr.log.Debug("dropping unbenched message",
+            zap.Stringer("nodeID", nodeID),
+            zap.Stringer("chainID", chainID),
+            zap.Error(errClosing),
+        )
+        return
+    }
+
     benchedChains := cr.benched[nodeID]
     benchedChains.Remove(chainID)
     if benchedChains.Len() != 0 {
diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go
index 19b889cd2d94..9eaae3071e15 100644
--- a/snow/networking/router/chain_router_test.go
+++ b/snow/networking/router/chain_router_test.go
@@ -191,6 +191,117 @@ func TestShutdown(t *testing.T) {
     require.Less(shutdownDuration, 250*time.Millisecond)
 }
 
+func TestConnectedAfterShutdownErrorLogRegression(t *testing.T) {
+    require := require.New(t)
+
+    snowCtx := snowtest.Context(t, snowtest.PChainID)
+    chainCtx := snowtest.ConsensusContext(snowCtx)
+
+    chainRouter := ChainRouter{}
+    require.NoError(chainRouter.Initialize(
+        ids.EmptyNodeID,
+        logging.NoWarn{}, // If an error log is emitted, the test will fail
+        nil,
+        time.Second,
+        set.Set[ids.ID]{},
+        true,
+        set.Set[ids.ID]{},
+        nil,
+        HealthConfig{},
+        prometheus.NewRegistry(),
+    ))
+
+    resourceTracker, err := tracker.NewResourceTracker(
+        prometheus.NewRegistry(),
+        resource.NoUsage,
+        meter.ContinuousFactory{},
+        time.Second,
+    )
+    require.NoError(err)
+
+    p2pTracker, err := p2p.NewPeerTracker(
+        logging.NoLog{},
+        "",
+        prometheus.NewRegistry(),
+        nil,
+        version.CurrentApp,
+    )
+    require.NoError(err)
+
+    h, err := handler.New(
+        chainCtx,
+        nil,
+        nil,
+        time.Second,
+        testThreadPoolSize,
+        resourceTracker,
+        validators.UnhandledSubnetConnector,
+        subnets.New(chainCtx.NodeID, subnets.Config{}),
+        commontracker.NewPeers(),
+        p2pTracker,
+        prometheus.NewRegistry(),
+    )
+    require.NoError(err)
+
+    engine := common.EngineTest{
+        T: t,
+        StartF: func(context.Context, uint32) error {
+            return nil
+        },
+        ContextF: func() *snow.ConsensusContext {
+            return chainCtx
+        },
+        HaltF: func(context.Context) {},
+        ShutdownF: func(context.Context) error {
+            return nil
+        },
+        ConnectedF: func(context.Context, ids.NodeID, *version.Application) error {
+            return nil
+        },
+    }
+    engine.Default(true)
+    engine.CantGossip = false
+
+    bootstrapper := &common.BootstrapperTest{
+        EngineTest: engine,
+        CantClear:  true,
+    }
+
+    h.SetEngineManager(&handler.EngineManager{
+        Avalanche: &handler.Engine{
+            StateSyncer:  nil,
+            Bootstrapper: bootstrapper,
+            Consensus:    &engine,
+        },
+        Snowman: &handler.Engine{
+            StateSyncer:  nil,
+            Bootstrapper: bootstrapper,
+            Consensus:    &engine,
+        },
+    })
+    chainCtx.State.Set(snow.EngineState{
+        Type:  engineType,
+        State: snow.NormalOp, // assumed bootstrapping is done
+    })
+
+    chainRouter.AddChain(context.Background(), h)
+
+    h.Start(context.Background(), false)
+
+    chainRouter.Shutdown(context.Background())
+
+    shutdownDuration, err := h.AwaitStopped(context.Background())
+    require.NoError(err)
+    require.GreaterOrEqual(shutdownDuration, time.Duration(0))
+
+    // Calling connected after shutdown should result in an error log.
+    chainRouter.Connected(
+        ids.GenerateTestNodeID(),
+        version.CurrentApp,
+        ids.GenerateTestID(),
+    )
+}
+
 func TestShutdownTimesOut(t *testing.T) {
     require := require.New(t)
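
The patch applies one pattern throughout the router: Shutdown sets a boolean closing flag while holding the same mutex that guards the handler map, and every externally callable method checks that flag at the top before touching per-peer or per-chain state. The following is a minimal, self-contained sketch of that pattern, not part of the patch itself; the router type, errShuttingDown, and handler names are illustrative only.

package main

import (
    "errors"
    "fmt"
    "sync"
)

var errShuttingDown = errors.New("router is closing")

// router is a toy stand-in for ChainRouter: a single mutex protects both the
// handler map and the closing flag, so Shutdown and later calls cannot race.
type router struct {
    lock     sync.Mutex
    closing  bool
    handlers map[string]func()
}

// Shutdown snapshots the handlers, marks the router as closing while still
// holding the lock, and only then stops the handlers outside the lock.
func (r *router) Shutdown() {
    r.lock.Lock()
    prev := r.handlers
    r.handlers = map[string]func(){}
    r.closing = true
    r.lock.Unlock()

    for _, stop := range prev {
        stop()
    }
}

// Connected drops the notification once the router has begun closing instead
// of touching state that Shutdown may already have torn down.
func (r *router) Connected(nodeID string) error {
    r.lock.Lock()
    defer r.lock.Unlock()
    if r.closing {
        return errShuttingDown
    }
    fmt.Println("tracking peer", nodeID)
    return nil
}

func main() {
    r := &router{handlers: map[string]func(){"chain": func() {}}}
    r.Shutdown()
    fmt.Println(r.Connected("NodeID-example")) // prints: router is closing
}

Because the flag is flipped before the lock is released in Shutdown, any caller that acquires the mutex afterwards is guaranteed to observe closing == true, which is why the added guards in RegisterRequest, HandleInbound, Connected, Disconnected, Benched, and Unbenched can simply return early.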