diff --git a/raft/raft.go b/raft/raft.go index 5639fcb8f3a..b17b292a17e 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -176,9 +176,14 @@ type raft struct { heartbeatTimeout int electionTimeout int - rand *rand.Rand - tick func() - step stepFunc + // randomizedElectionTimeout is a random number between + // [electiontimeout, 2 * electiontimeout - 1]. It gets reset + // when raft changes its state to follower or candidate. + randomizedElectionTimeout int + + rand *rand.Rand + tick func() + step stepFunc logger Logger } @@ -392,6 +397,7 @@ func (r *raft) reset(term uint64) { r.electionElapsed = 0 r.heartbeatElapsed = 0 + r.resetRandomizedElectionTimeout() r.votes = make(map[uint64]bool) for id := range r.prs { @@ -422,7 +428,7 @@ func (r *raft) tickElection() { return } r.electionElapsed++ - if r.isElectionTimeout() { + if r.pastElectionTimeout() { r.electionElapsed = 0 r.Step(pb.Message{From: r.id, Type: pb.MsgHup}) } @@ -863,15 +869,15 @@ func (r *raft) loadState(state pb.HardState) { r.Vote = state.Vote } -// isElectionTimeout returns true if r.electionElapsed is greater than the -// randomized election timeout in (electiontimeout, 2 * electiontimeout - 1). +// pastElectionTimeout returns true if r.electionElapsed is greater than the +// randomized election timeout in [electiontimeout, 2 * electiontimeout - 1]. // Otherwise, it returns false. -func (r *raft) isElectionTimeout() bool { - d := r.electionElapsed - r.electionTimeout - if d < 0 { - return false - } - return d > r.rand.Int()%r.electionTimeout +func (r *raft) pastElectionTimeout() bool { + return r.electionElapsed >= r.randomizedElectionTimeout +} + +func (r *raft) resetRandomizedElectionTimeout() { + r.randomizedElectionTimeout = r.electionTimeout + r.rand.Int()%r.electionTimeout } // checkQuorumActive returns true if the quorum is active from diff --git a/raft/raft_paper_test.go b/raft/raft_paper_test.go index ed78950238a..526fb322a84 100644 --- a/raft/raft_paper_test.go +++ b/raft/raft_paper_test.go @@ -381,8 +381,8 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) { } } - if g := float64(conflicts) / 1000; g > 0.4 { - t.Errorf("probability of conflicts = %v, want <= 0.4", g) + if g := float64(conflicts) / 1000; g > 0.3 { + t.Errorf("probability of conflicts = %v, want <= 0.3", g) } } diff --git a/raft/raft_test.go b/raft/raft_test.go index b40053aad7a..21cd453d8fe 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -754,9 +754,10 @@ func TestIsElectionTimeout(t *testing.T) { round bool }{ {5, 0, false}, - {13, 0.3, true}, - {15, 0.5, true}, - {18, 0.8, true}, + {10, 0.1, true}, + {13, 0.4, true}, + {15, 0.6, true}, + {18, 0.9, true}, {20, 1, false}, } @@ -765,7 +766,8 @@ func TestIsElectionTimeout(t *testing.T) { sm.electionElapsed = tt.elapse c := 0 for j := 0; j < 10000; j++ { - if sm.isElectionTimeout() { + sm.resetRandomizedElectionTimeout() + if sm.pastElectionTimeout() { c++ } }