Skip to content

Commit

Permalink
Ensure installSnapshot always consumes stream. Fixes issue hashicorp#212
Browse files Browse the repository at this point in the history
  • Loading branch information
superfell committed Jun 7, 2017
1 parent 0b14ef8 commit f5aa3cb
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ test:
go test -timeout=45s ./...

integ: test
INTEG_TESTS=yes go test -timeout=5s -run=Integ ./...
INTEG_TESTS=yes go test -timeout=25s -run=Integ ./...

deps:
go get -d -v ./...
Expand Down
124 changes: 101 additions & 23 deletions integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,32 @@ type RaftEnv struct {
}

// Release stops the node by calling Shutdown and then removes the
// directory tree at r.dir.
func (r *RaftEnv) Release() {
r.Shutdown()
// Best-effort cleanup: the error returned by RemoveAll is not checked.
os.RemoveAll(r.dir)
}

// Shutdown stops the raft instance and closes its transport. It panics if
// the raft shutdown future reports an error.
//
// The directory at r.dir is deliberately left in place: Release removes it
// after calling Shutdown, and a node that has only been shut down may be
// brought back with Restart.
func (r *RaftEnv) Shutdown() {
	r.logger.Warn("Shutdown node at %v", r.raft.serverInternals.localAddr)
	f := r.raft.Shutdown()
	if err := f.Error(); err != nil {
		panic(err)
	}
	r.trans.Close()
}

// Restart brings a node that was previously stopped with Shutdown() back up.
// It opens a fresh TCP transport on the address the old raft instance used,
// then builds a new raft instance from the env's existing conf, fsm, store
// and snapshot store. Any failure aborts the test via t.Fatalf.
func (r *RaftEnv) Restart(t *testing.T) {
	addr := string(r.raft.serverInternals.localAddr)
	transport, err := NewTCPTransport(addr, nil, 2, time.Second, nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	r.trans = transport
	r.logger.Info("Starting node at %v", transport.LocalAddr())

	node, err := NewRaft(r.conf, r.fsm, r.store, r.store, r.snapshot, r.trans)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	r.raft = node
}

func MakeRaft(t *testing.T, conf *Config, bootstrap bool) *RaftEnv {
Expand Down Expand Up @@ -91,6 +111,7 @@ func MakeRaft(t *testing.T, conf *Config, bootstrap bool) *RaftEnv {
}

log.Info("Starting node at %v", trans.LocalAddr())
conf.Logger = env.logger
raft, err := NewRaft(conf, env.fsm, stable, stable, snap, trans)
if err != nil {
t.Fatalf("err: %v", err)
Expand Down Expand Up @@ -157,33 +178,51 @@ func NoErr(err error, t *testing.T) {
// CheckConsistent verifies that the FSM of every env holds exactly the same
// log entries as the FSM of envs[0]. Replication may still be in flight, so a
// mismatch is retried every 20ms; the test is failed with the most recent
// mismatch once 400ms have elapsed. An empty envs slice is trivially
// consistent.
func CheckConsistent(envs []*RaftEnv, t *testing.T) {
	if len(envs) == 0 {
		return
	}
	limit := time.Now().Add(400 * time.Millisecond)
	first := envs[0]
	for {
		// Hold the reference FSM's lock for one full comparison pass so its
		// logs cannot change while the other envs are compared against it.
		first.fsm.Lock()
		var err error
		for _, env := range envs[1:] {
			if err = diffFSMLogs(first, env); err != nil {
				break
			}
		}
		first.fsm.Unlock()

		if err == nil {
			return
		}
		if time.Now().After(limit) {
			t.Fatalf("%v", err)
		}
		// Release the lock while sleeping so the reference FSM can keep
		// applying entries before the next pass.
		time.Sleep(20 * time.Millisecond)
	}
}

// diffFSMLogs returns an error describing the first difference between the
// FSM logs of first and env, or nil if they match; the caller must hold
// first.fsm's lock.
func diffFSMLogs(first, env *RaftEnv) error {
	env.fsm.Lock()
	defer env.fsm.Unlock()

	if l1, l2 := len(first.fsm.logs), len(env.fsm.logs); l1 != l2 {
		return fmt.Errorf("log length mismatch %d %d", l1, l2)
	}
	for idx, entry := range first.fsm.logs {
		other := env.fsm.logs[idx]
		if !bytes.Equal(entry, other) {
			return fmt.Errorf("log entry %d mismatch between %s/%s : '%s' / '%s'", idx, first.raft.serverInternals.localAddr, env.raft.serverInternals.localAddr, entry, other)
		}
	}
	return nil
}

// logBytes builds a log entry that starts with the prefix "test i " and is
// padded on the right with 'x' bytes until it is at least sz bytes long. If
// the prefix alone already reaches sz, it is returned unpadded.
func logBytes(i, sz int) []byte {
	entry := []byte(fmt.Sprintf("test %d ", i))
	if pad := sz - len(entry); pad > 0 {
		entry = append(entry, bytes.Repeat([]byte{'x'}, pad)...)
	}
	return entry
}

// Tests Raft by creating a cluster, growing it to 5 nodes while
// causing various stressful conditions
func TestRaft_Integ(t *testing.T) {
Expand All @@ -201,13 +240,21 @@ func TestRaft_Integ(t *testing.T) {
NoErr(WaitFor(env1, Leader), t)

// Do some commits
var futures []Future
for i := 0; i < 100; i++ {
futures = append(futures, env1.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
totalApplied := 0
applyAndWait := func(leader *RaftEnv, n, sz int) {
// Do some commits
var futures []ApplyFuture
for i := 0; i < n; i++ {
futures = append(futures, leader.raft.Apply(logBytes(i, sz), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
leader.logger.Debug("Applied at %d, size %d", f.Index(), sz)
}
totalApplied += n
}
// Do some commits
applyAndWait(env1, 100, 10)

// Do a snapshot
NoErr(WaitFuture(env1.raft.Snapshot(), t), t)
Expand All @@ -227,13 +274,46 @@ func TestRaft_Integ(t *testing.T) {
NoErr(err, t)

// Do some more commits
futures = nil
for i := 0; i < 100; i++ {
futures = append(futures, leader.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
applyAndWait(leader, 100, 10)

// Snapshot the leader
NoErr(WaitFuture(leader.raft.Snapshot(), t), t)
CheckConsistent(append([]*RaftEnv{env1}, envs...), t)

// shutdown a follower
disconnected := envs[len(envs)-1]
disconnected.Shutdown()

// Do some more commits [make sure the resulting snapshot will be a reasonable size]
applyAndWait(leader, 100, 10000)

// snapshot the leader [the leader's log should now be compacted past the disconnected follower's log]
NoErr(WaitFuture(leader.raft.Snapshot(), t), t)

// Unfortunately we need to wait for the leader to start backing off RPCs to the down follower
// such that when the follower comes back up it'll run an election before it gets an RPC from
// the leader
time.Sleep(time.Second * 5)

// start the now out of date follower back up
disconnected.Restart(t)

// wait for it to get caught up
timeout := time.Now().Add(time.Second * 10)
for {
dsf := disconnected.raft.Stats()
lsf := leader.raft.Stats()
WaitFuture(dsf, t)
WaitFuture(lsf, t)
if dsf.Stats().AppliedIndex == lsf.Stats().AppliedIndex {
break
}
time.Sleep(time.Millisecond)
if time.Now().After(timeout) {
t.Fatalf("Gave up waiting for follower to get caught up to leader")
}
}
CheckConsistent(append([]*RaftEnv{env1}, envs...), t)

// Shoot two nodes in the head!
rm1, rm2 := envs[0], envs[1]
Expand All @@ -247,13 +327,7 @@ func TestRaft_Integ(t *testing.T) {
NoErr(err, t)

// Do some more commits
futures = nil
for i := 0; i < 100; i++ {
futures = append(futures, leader.raft.Apply([]byte(fmt.Sprintf("test%d", i)), 0))
}
for _, f := range futures {
NoErr(WaitFuture(f, t), t)
}
applyAndWait(leader, 100, 10)

// Join a few new nodes!
for i := 0; i < 2; i++ {
Expand All @@ -264,6 +338,10 @@ func TestRaft_Integ(t *testing.T) {
envs = append(envs, env)
}

// Wait for a leader
leader, err = WaitForAny(Leader, append([]*RaftEnv{env1}, envs...))
NoErr(err, t)

// Remove the old nodes
NoErr(WaitFuture(leader.raft.RemoveServer(rm1.raft.serverInternals.localID, 0, 0), t), t)
NoErr(WaitFuture(leader.raft.RemoveServer(rm2.raft.serverInternals.localID, 0, 0), t), t)
Expand All @@ -279,8 +357,8 @@ func TestRaft_Integ(t *testing.T) {
allEnvs := append([]*RaftEnv{env1}, envs...)
CheckConsistent(allEnvs, t)

if len(env1.fsm.logs) != 300 {
t.Fatalf("should apply 300 logs! %d", len(env1.fsm.logs))
if len(env1.fsm.logs) != totalApplied {
t.Fatalf("should apply %d logs! %d", totalApplied, len(env1.fsm.logs))
}

for _, e := range envs {
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"container/list"
"fmt"
"io"
"io/ioutil"
"os"
"sort"
"sync"
Expand Down Expand Up @@ -1594,6 +1595,7 @@ func (r *raftServer) installSnapshot(rpc RPC, req *InstallSnapshotRequest) {
}
var rpcErr error
defer func() {
io.Copy(ioutil.Discard, rpc.Reader) // ensure we always consume all the snapshot data from the stream [see issue #212]
rpc.Respond(resp, rpcErr)
}()

Expand All @@ -1606,6 +1608,7 @@ func (r *raftServer) installSnapshot(rpc RPC, req *InstallSnapshotRequest) {

// Ignore an older term
if req.Term < r.currentTerm {
r.logger.Info("Ignoring installSnapshot request with older term of %d vs currentTerm %d", req.Term, r.currentTerm)
return
}

Expand Down

0 comments on commit f5aa3cb

Please sign in to comment.