diff --git a/raft/node.go b/raft/node.go index 33a9db840012..53d7e90dbee4 100644 --- a/raft/node.go +++ b/raft/node.go @@ -326,7 +326,9 @@ func (n *node) run(r *raft) { if cc.NodeID == None { r.resetPendingConf() select { - case n.confstatec <- pb.ConfState{Nodes: r.nodes()}: + case n.confstatec <- pb.ConfState{ + Nodes: r.nodes(), + Learners: r.learnerNodes()}: case <-n.done: } break @@ -349,7 +351,9 @@ func (n *node) run(r *raft) { panic("unexpected conf type") } select { - case n.confstatec <- pb.ConfState{Nodes: r.nodes()}: + case n.confstatec <- pb.ConfState{ + Nodes: r.nodes(), + Learners: r.learnerNodes()}: case <-n.done: } case <-n.tickc: diff --git a/raft/node_test.go b/raft/node_test.go index 4401412e7743..d1f20c4036d1 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -728,3 +728,54 @@ func TestIsHardStateEqual(t *testing.T) { } } } + +func TestNodeProposeAddLearnerNode(t *testing.T) { + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + stop := make(chan struct{}) + done := make(chan struct{}) + applyConfChan := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-stop: + return + case <-ticker.C: + n.Tick() + case rd := <-n.Ready(): + s.Append(rd.Entries) + t.Logf("raft: %v", rd.Entries) + for _, ent := range rd.Entries { + if ent.Type == raftpb.EntryConfChange { + var cc raftpb.ConfChange + cc.Unmarshal(ent.Data) + state := n.ApplyConfChange(cc) + if len(state.Learners) == 0 || + state.Learners[0] != cc.NodeID || + cc.NodeID != 2 { + t.Fatalf("apply conf change should return new added learner: %v", state.String()) + } + + if len(state.Nodes) != 1 { + t.Fatalf("add learner should not change the nodes: %v", state.String()) + } + t.Logf("apply raft conf %v changed to: %v", cc, state.String()) + applyConfChan <- struct{}{} + } + } + n.Advance() + } + } + }() + cc := raftpb.ConfChange{Type: raftpb.ConfChangeAddLearnerNode, NodeID: 2} + n.ProposeConfChange(context.TODO(), cc) + <-applyConfChan + close(stop) + <-done +} diff --git a/raft/raft.go b/raft/raft.go index 1b191d9ff261..e2bbc898c537 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -372,10 +372,16 @@ func (r *raft) hardState() pb.HardState { func (r *raft) quorum() int { return len(r.prs)/2 + 1 } func (r *raft) nodes() []uint64 { - nodes := make([]uint64, 0, len(r.prs)+len(r.learnerPrs)) + nodes := make([]uint64, 0, len(r.prs)) for id := range r.prs { nodes = append(nodes, id) } + sort.Sort(uint64Slice(nodes)) + return nodes +} + +func (r *raft) learnerNodes() []uint64 { + nodes := make([]uint64, 0, len(r.learnerPrs)) for id := range r.learnerPrs { nodes = append(nodes, id) } diff --git a/raft/raft_test.go b/raft/raft_test.go index c1416d05150a..d6852de22d89 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -2475,8 +2475,12 @@ func TestRestoreWithLearner(t *testing.T) { t.Errorf("log.lastTerm = %d, want %d", mustTerm(sm.raftLog.term(s.Metadata.Index)), s.Metadata.Term) } sg := sm.nodes() - if len(sg) != len(s.Metadata.ConfState.Nodes)+len(s.Metadata.ConfState.Learners) { - t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState) + if len(sg) != len(s.Metadata.ConfState.Nodes) { + t.Errorf("sm.Nodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Nodes) + } + lns := sm.learnerNodes() + if len(lns) != len(s.Metadata.ConfState.Learners) { + t.Errorf("sm.LearnerNodes = %+v, length not equal with %+v", sg, s.Metadata.ConfState.Learners) } for _, n := range s.Metadata.ConfState.Nodes { if sm.prs[n].IsLearner { @@ -2827,8 +2831,8 @@ func TestAddLearner(t *testing.T) { if r.pendingConf { t.Errorf("pendingConf = %v, want false", r.pendingConf) } - nodes := r.nodes() - wnodes := []uint64{1, 2} + nodes := r.learnerNodes() + wnodes := []uint64{2} if !reflect.DeepEqual(nodes, wnodes) { t.Errorf("nodes = %v, want %v", nodes, wnodes) } @@ -2908,9 +2912,13 @@ func TestRemoveLearner(t *testing.T) { t.Errorf("nodes = %v, want %v", g, w) } + w = []uint64{} + if g := r.learnerNodes(); !reflect.DeepEqual(g, w) { + t.Errorf("nodes = %v, want %v", g, w) + } + // remove all nodes from cluster r.removeNode(1) - w = []uint64{} if g := r.nodes(); !reflect.DeepEqual(g, w) { t.Errorf("nodes = %v, want %v", g, w) } diff --git a/raft/rawnode.go b/raft/rawnode.go index 925cb851c4ad..8be2fa8ed086 100644 --- a/raft/rawnode.go +++ b/raft/rawnode.go @@ -170,7 +170,7 @@ func (rn *RawNode) ProposeConfChange(cc pb.ConfChange) error { func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState { if cc.NodeID == None { rn.raft.resetPendingConf() - return &pb.ConfState{Nodes: rn.raft.nodes()} + return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()} } switch cc.Type { case pb.ConfChangeAddNode: @@ -184,7 +184,7 @@ func (rn *RawNode) ApplyConfChange(cc pb.ConfChange) *pb.ConfState { default: panic("unexpected conf type") } - return &pb.ConfState{Nodes: rn.raft.nodes()} + return &pb.ConfState{Nodes: rn.raft.nodes(), Learners: rn.raft.learnerNodes()} } // Step advances the state machine using the given message.