diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 64ceb5c42e..316aff3b38 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -192,6 +192,15 @@ web3._extend({ name: 'stopWS', call: 'admin_stopWS' }), + new web3._extend.Method({ + name: 'getMaxPeers', + call: 'admin_getMaxPeers' + }), + new web3._extend.Method({ + name: 'setMaxPeers', + call: 'admin_setMaxPeers', + params: 1 + }), ], properties: [ new web3._extend.Property({ diff --git a/node/api.go b/node/api.go index 1b32399f63..d838404f7d 100644 --- a/node/api.go +++ b/node/api.go @@ -61,6 +61,30 @@ type privateAdminAPI struct { node *Node // Node interfaced by this API } +// This function sets the param maxPeers for the node. If there are excess peers attached to the node, it will remove the difference. +func (api *privateAdminAPI) SetMaxPeers(maxPeers int) (bool, error) { + // Make sure the server is running, fail otherwise + server := api.node.Server() + if server == nil { + return false, ErrNodeStopped + } + + server.SetMaxPeers(maxPeers) + + return true, nil +} + +// This function gets the maxPeers param for the node. +func (api *privateAdminAPI) GetMaxPeers() (int, error) { + // Make sure the server is running, fail otherwise + server := api.node.Server() + if server == nil { + return 0, ErrNodeStopped + } + + return server.MaxPeers, nil +} + // AddPeer requests connecting to a remote node, and also maintaining the new // connection at all times, even reconnecting if it is lost. func (api *privateAdminAPI) AddPeer(url string) (bool, error) { diff --git a/p2p/server.go b/p2p/server.go index 138975e54b..7de8504bdc 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -307,6 +307,35 @@ func (srv *Server) Peers() []*Peer { return ps } +// This function retrieves the peers that are not trusted-peers +func (srv *Server) getNonTrustedPeers() []*Peer { + allPeers := srv.Peers() + + nontrustedPeers := []*Peer{} + + for _, peer := range allPeers { + if !peer.Info().Network.Trusted { + nontrustedPeers = append(nontrustedPeers, peer) + } + } + + return nontrustedPeers +} + +// SetMaxPeers sets the maximum number of peers that can be connected +func (srv *Server) SetMaxPeers(maxPeers int) { + currentPeers := srv.getNonTrustedPeers() + if len(currentPeers) > maxPeers { + peersToDrop := currentPeers[maxPeers:] + for _, peer := range peersToDrop { + log.Warn("CurrentPeers more than MaxPeers", "removing", peer.ID()) + srv.RemovePeer(peer.Node()) + } + } + + srv.MaxPeers = maxPeers +} + // PeerCount returns the number of connected peers. func (srv *Server) PeerCount() int { var count int @@ -368,6 +397,8 @@ func (srv *Server) RemoveTrustedPeer(node *enode.Node) { case srv.removetrusted <- node: case <-srv.quit: } + // Disconnect the peer if maxPeers is breached. + srv.SetMaxPeers(srv.MaxPeers) } // SubscribeEvents subscribes the given channel to peer events