diff --git a/raft/node.go b/raft/node.go index f3ba250b9af9..ccf538acee46 100644 --- a/raft/node.go +++ b/raft/node.go @@ -118,6 +118,10 @@ type Node interface { Campaign(ctx context.Context) error // Propose proposes that data be appended to the log. Propose(ctx context.Context, data []byte) error + // ProposeWithCancel proposes that data be appended to the log + // and invokes cancel if the proposal is dropped. + // Applications can use this to fail fast when a proposal is dropped or cancelled. + ProposeWithCancel(ctx context.Context, cancel context.CancelFunc, data []byte) error // ProposeConfChange proposes config change. // At most one ConfChange can be in the process of going through consensus. // Application needs to call ApplyConfChange when applying EntryConfChange type entry. @@ -224,10 +228,15 @@ func RestartNode(c *Config) Node { return &n } +type msgWithCancel struct { + m pb.Message + cancel context.CancelFunc +} + // node is the canonical implementation of the Node interface type node struct { - propc chan pb.Message - recvc chan pb.Message + propc chan msgWithCancel + recvc chan msgWithCancel confc chan pb.ConfChange confstatec chan pb.ConfState readyc chan Ready @@ -242,8 +251,8 @@ type node struct { func newNode() node { return node{ - propc: make(chan pb.Message), - recvc: make(chan pb.Message), + propc: make(chan msgWithCancel), + recvc: make(chan msgWithCancel), confc: make(chan pb.ConfChange), confstatec: make(chan pb.ConfState), readyc: make(chan Ready), @@ -271,7 +280,7 @@ func (n *node) Stop() { } func (n *node) run(r *raft) { - var propc chan pb.Message + var propc chan msgWithCancel var readyc chan Ready var advancec chan struct{} var prevLastUnstablei, prevLastUnstablet uint64 @@ -314,13 +323,21 @@ func (n *node) run(r *raft) { // TODO: maybe buffer the config propose if there exists one (the way // described in raft dissertation) // Currently it is dropped in Step silently. 
- case m := <-propc: + case mc := <-propc: + m := mc.m m.From = r.id - r.Step(m) - case m := <-n.recvc: + err := r.Step(m) + if err == ErrProposalDropped && mc.cancel != nil { + mc.cancel() + } + case mc := <-n.recvc: + m := mc.m // filter out response message from unknown From. if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) { - r.Step(m) // raft never returns an error + err := r.Step(m) // checked below for ErrProposalDropped + if err == ErrProposalDropped && mc.cancel != nil { + mc.cancel() + } } case cc := <-n.confc: if cc.NodeID == None { @@ -411,6 +428,10 @@ func (n *node) Propose(ctx context.Context, data []byte) error { return n.step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}}) } +func (n *node) ProposeWithCancel(ctx context.Context, cancel context.CancelFunc, data []byte) error { + return n.stepWithCancel(ctx, cancel, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}}) +} + func (n *node) Step(ctx context.Context, m pb.Message) error { // ignore unexpected local messages receiving over network if IsLocalMsg(m.Type) { @@ -428,16 +449,15 @@ func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error { return n.Step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange, Data: data}}}) } -// Step advances the state machine using msgs. The ctx.Err() will be returned, -// if any. -func (n *node) step(ctx context.Context, m pb.Message) error { +func (n *node) stepWithCancel(ctx context.Context, cancel context.CancelFunc, m pb.Message) error { ch := n.recvc if m.Type == pb.MsgProp { ch = n.propc } + mc := msgWithCancel{m: m, cancel: cancel} select { - case ch <- m: + case ch <- mc: return nil case <-ctx.Done(): return ctx.Err() @@ -446,6 +466,12 @@ func (n *node) step(ctx context.Context, m pb.Message) error { } } +// step advances the state machine using msgs. The ctx.Err() will be returned, +// if any. 
+func (n *node) step(ctx context.Context, m pb.Message) error { + return n.stepWithCancel(ctx, nil, m) +} + func (n *node) Ready() <-chan Ready { return n.readyc } func (n *node) Advance() { @@ -480,7 +506,7 @@ func (n *node) Status() Status { func (n *node) ReportUnreachable(id uint64) { select { - case n.recvc <- pb.Message{Type: pb.MsgUnreachable, From: id}: + case n.recvc <- msgWithCancel{m: pb.Message{Type: pb.MsgUnreachable, From: id}, cancel: nil}: case <-n.done: } } @@ -489,7 +515,7 @@ func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) { rej := status == SnapshotFailure select { - case n.recvc <- pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}: + case n.recvc <- msgWithCancel{m: pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}, cancel: nil}: case <-n.done: } } @@ -497,7 +523,7 @@ func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) { func (n *node) TransferLeadership(ctx context.Context, lead, transferee uint64) { select { // manually set 'from' and 'to', so that leader can voluntarily transfers its leadership - case n.recvc <- pb.Message{Type: pb.MsgTransferLeader, From: transferee, To: lead}: + case n.recvc <- msgWithCancel{m: pb.Message{Type: pb.MsgTransferLeader, From: transferee, To: lead}, cancel: nil}: case <-n.done: case <-ctx.Done(): } diff --git a/raft/node_test.go b/raft/node_test.go index f884f3319a5f..2566923761f1 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "reflect" + "strings" "testing" "time" @@ -30,8 +31,8 @@ import ( func TestNodeStep(t *testing.T) { for i, msgn := range raftpb.MessageType_name { n := &node{ - propc: make(chan raftpb.Message, 1), - recvc: make(chan raftpb.Message, 1), + propc: make(chan msgWithCancel, 1), + recvc: make(chan msgWithCancel, 1), } msgt := raftpb.MessageType(i) n.Step(context.TODO(), raftpb.Message{Type: msgt}) @@ -64,7 +65,7 @@ func TestNodeStep(t *testing.T) { func TestNodeStepUnblock(t *testing.T) { // a node 
without buffer to block step n := &node{ - propc: make(chan raftpb.Message), + propc: make(chan msgWithCancel), done: make(chan struct{}), } @@ -433,6 +434,171 @@ func TestBlockProposal(t *testing.T) { } } +// TestNodeProposeWithCancelNormal ensures that a proposal accepted by the +// step function is delivered without the supplied cancel being invoked. +func TestNodeProposeWithCancelNormal(t *testing.T) { + msgs := []raftpb.Message{} + normalMsg := []byte("normal_message") + normalDoneCh := make(chan bool) + appendStep := func(r *raft, m raftpb.Message) error { + if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(normalMsg)) { + close(normalDoneCh) + } + msgs = append(msgs, m) + return nil + } + + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + for { + rd := <-n.Ready() + s.Append(rd.Entries) + // switch the step function to appendStep once this raft becomes leader + if rd.SoftState.Lead == r.id { + r.step = appendStep + n.Advance() + break + } + n.Advance() + } + proposalTimeout := time.Millisecond * 100 + + ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout) + err := n.ProposeWithCancel(ctx, cancel, normalMsg) + if err != nil { + t.Errorf("should propose success: %v", err) + } + select { + case <-ctx.Done(): + t.Errorf("should not fail for normal proposal: %v", ctx.Err()) + case <-time.After(proposalTimeout): + t.Errorf("should return early for normal proposal") + case <-normalDoneCh: + } + cancel() + + n.Stop() + if len(msgs) != 1 { + t.Fatalf("len(msgs) = %d, want %d", len(msgs), 1) + } + if msgs[0].Type != raftpb.MsgProp { + t.Errorf("msg type = %d, want %d", msgs[0].Type, raftpb.MsgProp) + } + if !bytes.Equal(msgs[0].Entries[0].Data, normalMsg) { + t.Errorf("data = %v, want %v", msgs[0].Entries[0].Data, normalMsg) + } +} + +// TestNodeProposeWithCancelDropped ensures that ProposeWithCancel invokes the +// supplied cancel promptly when the step function drops the proposal. +func TestNodeProposeWithCancelDropped(t *testing.T) { + msgs := []raftpb.Message{} + droppingMsg := []byte("test_dropping") + dropStep := func(r *raft, m raftpb.Message) error { + if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(droppingMsg)) { + 
t.Logf("dropping message: %v", m.String()) + return ErrProposalDropped + } + msgs = append(msgs, m) + return nil + } + + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + for { + rd := <-n.Ready() + s.Append(rd.Entries) + // switch the step function to dropStep once this raft becomes leader + if rd.SoftState.Lead == r.id { + r.step = dropStep + n.Advance() + break + } + n.Advance() + } + proposalTimeout := time.Millisecond * 100 + ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout) + // propose with cancel should be cancelled early if dropped + err := n.ProposeWithCancel(ctx, cancel, droppingMsg) + if err != nil { + t.Errorf("should propose success: %v", err) + } + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + t.Errorf("should cancel propose for dropped proposal with cancel") + } + case <-time.After(proposalTimeout / 2): + t.Errorf("should return early for dropped proposal") + } + cancel() + + n.Stop() + if len(msgs) != 0 { + t.Fatalf("len(msgs) = %d, want %d", len(msgs), 0) + } +} + +// TestNodeProposeWithNoCancelDropped ensures that a plain Propose leaves the +// caller's context alone when the proposal is dropped, so it ends by timeout. +func TestNodeProposeWithNoCancelDropped(t *testing.T) { + msgs := []raftpb.Message{} + droppingMsg := []byte("test_dropping") + dropStep := func(r *raft, m raftpb.Message) error { + if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(droppingMsg)) { + t.Logf("dropping message: %v", m.String()) + return ErrProposalDropped + } + msgs = append(msgs, m) + return nil + } + + n := newNode() + s := NewMemoryStorage() + r := newTestRaft(1, []uint64{1}, 10, 1, s) + go n.run(r) + n.Campaign(context.TODO()) + for { + rd := <-n.Ready() + s.Append(rd.Entries) + // switch the step function to dropStep once this raft becomes leader + if rd.SoftState.Lead == r.id { + r.step = dropStep + n.Advance() + break + } + n.Advance() + } + proposalTimeout := time.Millisecond * 100 + + ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout) + // normal 
propose should wait until timeout if dropped + err := n.Propose(ctx, droppingMsg) + if err != nil { + t.Errorf("should propose success: %v", err) + } + select { + case <-ctx.Done(): + if ctx.Err() != context.DeadlineExceeded { + t.Errorf("should timeout propose for dropped proposal with no cancel") + } + case <-time.After(proposalTimeout * 2): + t.Errorf("should hit the deadline for dropped proposal with no cancel") + } + cancel() + + n.Stop() + if len(msgs) != 0 { + t.Fatalf("len(msgs) = %d, want %d", len(msgs), 0) + } +} + // TestNodeTick ensures that node.Tick() will increase the // elapsed of the underlying raft state machine. func TestNodeTick(t *testing.T) {