diff --git a/chains/manager.go b/chains/manager.go index a1158e67716c..e764a5dfb3ce 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -859,13 +859,14 @@ func (m *manager) createAvalancheChain( // Create engine, bootstrapper and state-syncer in this order, // to make sure start callbacks are duly initialized snowmanEngineConfig := smeng.Config{ - Ctx: ctx, - AllGetsServer: snowGetHandler, - VM: vmWrappingProposerVM, - Sender: snowmanMessageSender, - Validators: vdrs, - Params: consensusParams, - Consensus: snowmanConsensus, + Ctx: ctx, + AllGetsServer: snowGetHandler, + VM: vmWrappingProposerVM, + Sender: snowmanMessageSender, + Validators: vdrs, + ConnectedValidators: connectedValidators, + Params: consensusParams, + Consensus: snowmanConsensus, } snowmanEngine, err := smeng.New(snowmanEngineConfig) if err != nil { @@ -1201,14 +1202,15 @@ func (m *manager) createSnowmanChain( // Create engine, bootstrapper and state-syncer in this order, // to make sure start callbacks are duly initialized engineConfig := smeng.Config{ - Ctx: ctx, - AllGetsServer: snowGetHandler, - VM: vm, - Sender: messageSender, - Validators: vdrs, - Params: consensusParams, - Consensus: consensus, - PartialSync: m.PartialSyncPrimaryNetwork && ctx.ChainID == constants.PlatformChainID, + Ctx: ctx, + AllGetsServer: snowGetHandler, + VM: vm, + Sender: messageSender, + Validators: vdrs, + ConnectedValidators: connectedValidators, + Params: consensusParams, + Consensus: consensus, + PartialSync: m.PartialSyncPrimaryNetwork && ctx.ChainID == constants.PlatformChainID, } engine, err := smeng.New(engineConfig) if err != nil { diff --git a/node/overridden_manager.go b/node/overridden_manager.go index 80295f8636ea..91d8c198a4c3 100644 --- a/node/overridden_manager.go +++ b/node/overridden_manager.go @@ -68,10 +68,6 @@ func (o *overriddenManager) Sample(_ ids.ID, size int) ([]ids.NodeID, error) { return o.manager.Sample(o.subnetID, size) } -func (o *overriddenManager) UniformSample(_ ids.ID, size int) ([]ids.NodeID, error) { - return o.manager.UniformSample(o.subnetID, size) -} - func (o *overriddenManager) GetMap(ids.ID) map[ids.NodeID]*validators.GetValidatorOutput { return o.manager.GetMap(o.subnetID) } diff --git a/snow/engine/common/tracker/peers.go b/snow/engine/common/tracker/peers.go index ad9592209a5a..94d653a53b1f 100644 --- a/snow/engine/common/tracker/peers.go +++ b/snow/engine/common/tracker/peers.go @@ -33,6 +33,9 @@ type Peers interface { ConnectedPercent() float64 // TotalWeight returns the total validator weight TotalWeight() uint64 + // SampleValidator returns a randomly selected connected validator. If there + // are no currently connected validators then it will return false. + SampleValidator() (ids.NodeID, bool) // PreferredPeers returns the currently connected validators. If there are // no currently connected validators then it will return the currently // connected peers. @@ -108,6 +111,13 @@ func (p *lockedPeers) TotalWeight() uint64 { return p.peers.TotalWeight() } +func (p *lockedPeers) SampleValidator() (ids.NodeID, bool) { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.peers.SampleValidator() +} + func (p *lockedPeers) PreferredPeers() set.Set[ids.NodeID] { p.lock.RLock() defer p.lock.RUnlock() @@ -263,6 +273,10 @@ func (p *peerData) TotalWeight() uint64 { return p.totalWeight } +func (p *peerData) SampleValidator() (ids.NodeID, bool) { + return p.connectedValidators.Peek() +} + func (p *peerData) PreferredPeers() set.Set[ids.NodeID] { if p.connectedValidators.Len() == 0 { connectedPeers := set.NewSet[ids.NodeID](p.connectedPeers.Len()) diff --git a/snow/engine/snowman/config.go b/snow/engine/snowman/config.go index ed63af2f4936..65a24a2ea816 100644 --- a/snow/engine/snowman/config.go +++ b/snow/engine/snowman/config.go @@ -8,6 +8,7 @@ import ( "github.com/ava-labs/avalanchego/snow/consensus/snowball" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/engine/common/tracker" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" "github.com/ava-labs/avalanchego/snow/validators" ) @@ -16,11 +17,12 @@ import ( type Config struct { common.AllGetsServer - Ctx *snow.ConsensusContext - VM block.ChainVM - Sender common.Sender - Validators validators.Manager - Params snowball.Parameters - Consensus snowman.Consensus - PartialSync bool + Ctx *snow.ConsensusContext + VM block.ChainVM + Sender common.Sender + Validators validators.Manager + ConnectedValidators tracker.Peers + Params snowball.Parameters + Consensus snowman.Consensus + PartialSync bool } diff --git a/snow/engine/snowman/config_test.go b/snow/engine/snowman/config_test.go index 23fc0fc39fd4..9611990d9d95 100644 --- a/snow/engine/snowman/config_test.go +++ b/snow/engine/snowman/config_test.go @@ -8,16 +8,18 @@ import ( "github.com/ava-labs/avalanchego/snow/consensus/snowball" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/snow/engine/common/tracker" "github.com/ava-labs/avalanchego/snow/engine/snowman/block" "github.com/ava-labs/avalanchego/snow/validators" ) func DefaultConfig() Config { return Config{ - Ctx: snow.DefaultConsensusContextTest(), - VM: &block.TestVM{}, - Sender: &common.SenderTest{}, - Validators: validators.NewManager(), + Ctx: snow.DefaultConsensusContextTest(), + VM: &block.TestVM{}, + Sender: &common.SenderTest{}, + Validators: validators.NewManager(), + ConnectedValidators: tracker.NewPeers(), Params: snowball.Parameters{ K: 1, AlphaPreference: 1, diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index 7f06698cbab0..4b43dcda0acb 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -169,11 +169,10 @@ func (t *Transitive) Gossip(ctx context.Context) error { // Uniform sampling is used here to reduce bandwidth requirements of // nodes with a large amount of stake weight. - vdrIDs, err := t.Validators.UniformSample(t.Ctx.SubnetID, 1) - if err != nil { + vdrID, ok := t.ConnectedValidators.SampleValidator() + if !ok { t.Ctx.Log.Error("skipping block gossip", - zap.String("reason", "no validators"), - zap.Error(err), + zap.String("reason", "no connected validators"), ) return nil } @@ -190,9 +189,13 @@ func (t *Transitive) Gossip(ctx context.Context) error { } t.requestID++ - vdrSet := set.Of(vdrIDs...) - preferredID := t.Consensus.Preference() - t.Sender.SendPullQuery(ctx, vdrSet, t.requestID, preferredID, nextHeightToAccept) + t.Sender.SendPullQuery( + ctx, + set.Of(vdrID), + t.requestID, + t.Consensus.Preference(), + nextHeightToAccept, + ) } else { t.Ctx.Log.Debug("skipping block gossip", zap.String("reason", "blocks currently processing"), diff --git a/snow/engine/snowman/transitive_test.go b/snow/engine/snowman/transitive_test.go index 26f6c1127bbf..738f20440c58 100644 --- a/snow/engine/snowman/transitive_test.go +++ b/snow/engine/snowman/transitive_test.go @@ -22,6 +22,7 @@ import ( "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/version" ) var ( @@ -41,6 +42,9 @@ func setup(t *testing.T, engCfg Config) (ids.NodeID, validators.Manager, *common vdr := ids.GenerateTestNodeID() require.NoError(vals.AddStaker(engCfg.Ctx.SubnetID, vdr, nil, ids.Empty, 1)) + require.NoError(engCfg.ConnectedValidators.Connected(context.Background(), vdr, version.CurrentApp)) + + vals.RegisterCallbackListener(engCfg.Ctx.SubnetID, engCfg.ConnectedValidators) sender := &common.SenderTest{T: t} engCfg.Sender = sender diff --git a/snow/validators/manager.go b/snow/validators/manager.go index 8cf634f29bd7..c42ea779d96b 100644 --- a/snow/validators/manager.go +++ b/snow/validators/manager.go @@ -85,10 +85,6 @@ type Manager interface { // If sampling the requested size isn't possible, an error will be returned. Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) - // UniformSample returns a collection of validatorIDs in the subnet. - // If sampling the requested size isn't possible, an error will be returned. - UniformSample(subnetID ids.ID, size int) ([]ids.NodeID, error) - // Map of the validators in this subnet GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput @@ -257,21 +253,6 @@ func (m *manager) Sample(subnetID ids.ID, size int) ([]ids.NodeID, error) { return set.Sample(size) } -func (m *manager) UniformSample(subnetID ids.ID, size int) ([]ids.NodeID, error) { - if size == 0 { - return nil, nil - } - - m.lock.RLock() - set, exists := m.subnetToVdrs[subnetID] - m.lock.RUnlock() - if !exists { - return nil, ErrMissingValidators - } - - return set.UniformSample(size) -} - func (m *manager) GetMap(subnetID ids.ID) map[ids.NodeID]*GetValidatorOutput { m.lock.RLock() set, exists := m.subnetToVdrs[subnetID] diff --git a/snow/validators/set.go b/snow/validators/set.go index 564cd107153a..dfa294a70bbe 100644 --- a/snow/validators/set.go +++ b/snow/validators/set.go @@ -243,13 +243,6 @@ func (s *vdrSet) Sample(size int) ([]ids.NodeID, error) { return s.sample(size) } -func (s *vdrSet) UniformSample(size int) ([]ids.NodeID, error) { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.uniformSample(size) -} - func (s *vdrSet) sample(size int) ([]ids.NodeID, error) { if !s.samplerInitialized { if err := s.sampler.Initialize(s.weights); err != nil { @@ -270,22 +263,6 @@ func (s *vdrSet) sample(size int) ([]ids.NodeID, error) { return list, nil } -func (s *vdrSet) uniformSample(size int) ([]ids.NodeID, error) { - uniform := sampler.NewUniform() - uniform.Initialize(uint64(len(s.vdrSlice))) - - indices, err := uniform.Sample(size) - if err != nil { - return nil, err - } - - list := make([]ids.NodeID, size) - for i, index := range indices { - list[i] = s.vdrSlice[index].NodeID - } - return list, nil -} - func (s *vdrSet) TotalWeight() (uint64, error) { s.lock.RLock() defer s.lock.RUnlock() diff --git a/utils/set/set.go b/utils/set/set.go index 7ab330fcc066..fd6525b6b127 100644 --- a/utils/set/set.go +++ b/utils/set/set.go @@ -184,7 +184,7 @@ func (s Set[_]) MarshalJSON() ([]byte, error) { return jsonBuf.Bytes(), errs.Err } -// Returns an element. If the set is empty, returns false +// Returns a random element. If the set is empty, returns false func (s *Set[T]) Peek() (T, bool) { for elt := range *s { return elt, true