diff --git a/sync/handler_hello.go b/sync/handler_hello.go index 25f10c5f9..5a6972dda 100644 --- a/sync/handler_hello.go +++ b/sync/handler_hello.go @@ -78,7 +78,7 @@ func (handler *helloHandler) ParseMessage(m message.Message, pid peer.ID) error } handler.peerSet.UpdateHeight(pid, msg.Height, msg.BlockHash) - handler.peerSet.UpdateStatus(pid, peerset.StatusCodeKnown) + handler.peerSet.UpdateStatus(pid, peerset.StatusCodeConnected) response := message.NewHelloAckMessage(message.ResponseCodeOK, "Ok", handler.state.LastBlockHeight()) handler.acknowledge(response, pid) diff --git a/sync/handler_hello_test.go b/sync/handler_hello_test.go index aa980dc78..2dc88d3ad 100644 --- a/sync/handler_hello_test.go +++ b/sync/handler_hello_test.go @@ -136,7 +136,7 @@ func TestParsingHelloMessages(t *testing.T) { p := td.sync.peerSet.GetPeer(pid) pub := valKey.PublicKey() - assert.Equal(t, p.Status, peerset.StatusCodeKnown) + assert.Equal(t, p.Status, peerset.StatusCodeConnected) assert.Equal(t, p.Agent, version.NodeAgent.String()) assert.Equal(t, p.Moniker, "kitty") assert.Contains(t, p.ConsensusKeys, pub) diff --git a/sync/peerset/peer_set.go b/sync/peerset/peer_set.go index 135805737..8952a975f 100644 --- a/sync/peerset/peer_set.go +++ b/sync/peerset/peer_set.go @@ -1,7 +1,6 @@ package peerset import ( - "fmt" "sync" "time" @@ -18,9 +17,7 @@ type PeerSet struct { lk sync.RWMutex peers map[peer.ID]*Peer - sessions map[int]*session.Session - nextSessionID int - sessionTimeout time.Duration + sessionManager *session.Manager totalSentBundles int totalSentBytes int64 totalReceivedBytes int64 @@ -32,8 +29,7 @@ type PeerSet struct { func NewPeerSet(sessionTimeout time.Duration) *PeerSet { return &PeerSet{ peers: make(map[peer.ID]*Peer), - sessions: make(map[int]*session.Session), - sessionTimeout: sessionTimeout, + sessionManager: session.NewManager(sessionTimeout), sentBytes: make(map[message.Type]int64), receivedBytes: make(map[message.Type]int64), startedAt: time.Now(), @@ -44,9 +40,7 @@ func (ps *PeerSet) OpenSession(pid peer.ID, from, count uint32) *session.Session ps.lk.Lock() defer ps.lk.Unlock() - ssn := session.NewSession(ps.nextSessionID, pid, from, count) - ps.sessions[ssn.SessionID] = ssn - ps.nextSessionID++ + ssn := ps.sessionManager.OpenSession(pid, from, count) p := ps.mustGetPeer(pid) p.TotalSessions++ @@ -58,127 +52,64 @@ func (ps *PeerSet) FindSession(sid int) *session.Session { ps.lk.RLock() defer ps.lk.RUnlock() - ssn, ok := ps.sessions[sid] - if ok { - return ssn - } - - return nil + return ps.sessionManager.FindSession(sid) } func (ps *PeerSet) NumberOfSessions() int { ps.lk.Lock() defer ps.lk.Unlock() - return len(ps.sessions) + return ps.sessionManager.NumberOfSessions() } func (ps *PeerSet) HasOpenSession(pid peer.ID) bool { ps.lk.RLock() defer ps.lk.RUnlock() - for _, ssn := range ps.sessions { - if ssn.PeerID == pid && ssn.Status == session.Open { - return true - } - } - - return false -} - -type SessionStats struct { - Total int - Open int - Completed int - Uncompleted int + return ps.sessionManager.HasOpenSession(pid) } -func (ss *SessionStats) String() string { - return fmt.Sprintf("total: %v, open: %v, completed: %v, uncompleted: %v", - ss.Total, ss.Open, ss.Completed, ss.Uncompleted) -} - -func (ps *PeerSet) SessionStats() SessionStats { +func (ps *PeerSet) SessionStats() session.Stats { ps.lk.RLock() defer ps.lk.RUnlock() - total := len(ps.sessions) - open := 0 - completed := 0 - unCompleted := 0 - for _, ssn := range ps.sessions { - switch ssn.Status { - case session.Open: - open++ - - case session.Completed: - completed++ - - case session.Uncompleted: - unCompleted++ - } - } - - return SessionStats{ - Total: total, - Open: open, - Completed: completed, - Uncompleted: unCompleted, - } + return ps.sessionManager.Stats() } func (ps *PeerSet) HasAnyOpenSession() bool { ps.lk.RLock() defer ps.lk.RUnlock() - for _, ssn := range ps.sessions { - if ssn.Status == session.Open { - return true - } - } - - return false + return ps.sessionManager.HasAnyOpenSession() } func (ps *PeerSet) UpdateSessionLastActivity(sid int) { ps.lk.Lock() defer ps.lk.Unlock() - ssn := ps.sessions[sid] - if ssn != nil { - ssn.LastActivity = time.Now() - } + ps.sessionManager.UpdateSessionLastActivity(sid) } func (ps *PeerSet) SetExpiredSessionsAsUncompleted() { ps.lk.Lock() defer ps.lk.Unlock() - for _, ssn := range ps.sessions { - if ps.sessionTimeout < util.Now().Sub(ssn.LastActivity) { - ssn.Status = session.Uncompleted - } - } + ps.sessionManager.SetExpiredSessionsAsUncompleted() } func (ps *PeerSet) SetSessionUncompleted(sid int) { ps.lk.Lock() defer ps.lk.Unlock() - ssn := ps.sessions[sid] - if ssn != nil { - ssn.Status = session.Uncompleted - } + ps.sessionManager.SetSessionUncompleted(sid) } func (ps *PeerSet) SetSessionCompleted(sid int) { ps.lk.Lock() defer ps.lk.Unlock() - ssn := ps.sessions[sid] + ssn := ps.sessionManager.SetSessionCompleted(sid) if ssn != nil { - ssn.Status = session.Completed - p := ps.mustGetPeer(ssn.PeerID) p.CompletedSessions++ } @@ -188,7 +119,7 @@ func (ps *PeerSet) RemoveAllSessions() { ps.lk.Lock() defer ps.lk.Unlock() - ps.sessions = make(map[int]*session.Session) + ps.sessionManager.RemoveAllSessions() } func (ps *PeerSet) Len() int { @@ -274,7 +205,7 @@ func (ps *PeerSet) UpdateStatus(pid peer.ID, status StatusCode) { p.Status = status if status == StatusCodeDisconnected { - for _, ssn := range ps.sessions { + for _, ssn := range ps.sessionManager.Sessions() { if ssn.PeerID == pid { ssn.Status = session.Uncompleted } @@ -415,13 +346,7 @@ func (ps *PeerSet) Sessions() []*session.Session { ps.lk.RLock() defer ps.lk.RUnlock() - sessions := make([]*session.Session, 0, len(ps.sessions)) - - for _, ssn := range ps.sessions { - sessions = append(sessions, ssn) - } - - return sessions + return ps.sessionManager.Sessions() } // GetRandomPeer selects a random peer from the peer set based on their download score. diff --git a/sync/peerset/peer_set_test.go b/sync/peerset/peer_set_test.go index cd9a207e7..f06616a80 100644 --- a/sync/peerset/peer_set_test.go +++ b/sync/peerset/peer_set_test.go @@ -166,7 +166,6 @@ func TestOpenSession(t *testing.T) { assert.True(t, ps.HasOpenSession(pid)) assert.False(t, ps.HasOpenSession(ts.RandPeerID())) assert.Equal(t, 1, ps.NumberOfSessions()) - assert.Contains(t, ps.Sessions(), ssn) } func TestFindSession(t *testing.T) { diff --git a/sync/peerset/session/manager.go b/sync/peerset/session/manager.go new file mode 100644 index 000000000..651953382 --- /dev/null +++ b/sync/peerset/session/manager.go @@ -0,0 +1,133 @@ +package session + +import ( + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/pactus-project/pactus/util" +) + +type Manager struct { + sessionTimeout time.Duration + sessions map[int]*Session + nextSessionID int +} + +func NewManager(sessionTimeout time.Duration) *Manager { + return &Manager{ + sessionTimeout: sessionTimeout, + sessions: make(map[int]*Session), + } +} + +func (sm *Manager) Stats() Stats { + total := len(sm.sessions) + open := 0 + completed := 0 + unCompleted := 0 + for _, ssn := range sm.sessions { + switch ssn.Status { + case Open: + open++ + + case Completed: + completed++ + + case Uncompleted: + unCompleted++ + } + } + + return Stats{ + Total: total, + Open: open, + Completed: completed, + Uncompleted: unCompleted, + } +} + +func (sm *Manager) OpenSession(pid peer.ID, from, count uint32) *Session { + ssn := NewSession(sm.nextSessionID, pid, from, count) + sm.sessions[ssn.SessionID] = ssn + sm.nextSessionID++ + + return ssn +} + +func (sm *Manager) FindSession(sid int) *Session { + ssn, ok := sm.sessions[sid] + if ok { + return ssn + } + + return nil +} + +func (sm *Manager) NumberOfSessions() int { + return len(sm.sessions) +} + +func (sm *Manager) HasOpenSession(pid peer.ID) bool { + for _, ssn := range sm.sessions { + if ssn.PeerID == pid && ssn.Status == Open { + return true + } + } + + return false +} + +func (sm *Manager) HasAnyOpenSession() bool { + for _, ssn := range sm.sessions { + if ssn.Status == Open { + return true + } + } + + return false +} + +func (sm *Manager) UpdateSessionLastActivity(sid int) { + ssn := sm.sessions[sid] + if ssn != nil { + ssn.LastActivity = time.Now() + } +} + +func (sm *Manager) SetExpiredSessionsAsUncompleted() { + for _, ssn := range sm.sessions { + if sm.sessionTimeout < util.Now().Sub(ssn.LastActivity) { + ssn.Status = Uncompleted + } + } +} + +func (sm *Manager) SetSessionUncompleted(sid int) { + ssn := sm.sessions[sid] + if ssn != nil { + ssn.Status = Uncompleted + } +} + +func (sm *Manager) SetSessionCompleted(sid int) *Session { + ssn := sm.sessions[sid] + if ssn != nil { + ssn.Status = Completed + } + + return ssn +} + +func (sm *Manager) RemoveAllSessions() { + sm.sessions = make(map[int]*Session) +} + +func (sm *Manager) Sessions() []*Session { + sessions := make([]*Session, 0, len(sm.sessions)) + + for _, ssn := range sm.sessions { + sessions = append(sessions, ssn) + } + + return sessions +} diff --git a/sync/peerset/session/stats.go b/sync/peerset/session/stats.go new file mode 100644 index 000000000..6d8f79a3e --- /dev/null +++ b/sync/peerset/session/stats.go @@ -0,0 +1,15 @@ +package session + +import "fmt" + +type Stats struct { + Total int + Open int + Completed int + Uncompleted int +} + +func (ss *Stats) String() string { + return fmt.Sprintf("total: %v, open: %v, completed: %v, uncompleted: %v", + ss.Total, ss.Open, ss.Completed, ss.Uncompleted) +}