Skip to content

Commit

Permalink
Revert to older snapshots if the latest raft snapshot is corrupted
Browse files Browse the repository at this point in the history
Fix #1040
  • Loading branch information
dgnorton authored and jvshahid committed Oct 20, 2014
1 parent 7df3781 commit 5e9750f
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 56 deletions.
98 changes: 43 additions & 55 deletions _vendor/raft/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
"path/filepath"
"sort"
"sync"
"time"
Expand Down Expand Up @@ -1253,6 +1253,31 @@ func (s *server) saveSnapshot() error {
return nil
}

// Returns a list of available snapshot names sorted newest to oldest
func (s *server) SnapshotList() ([]string, error) {
// Get FileInfo for everything in the snapshot dir
ssdir := path.Join(s.path, "snapshot")
finfos, err := ioutil.ReadDir(ssdir)

if err != nil {
return nil, err
}

// Build a list of snapshot file names
var ssnames []string
for _, finfo := range finfos {
fname := finfo.Name()
if finfo.Mode().IsRegular() && filepath.Ext(fname) == ".ss" {
ssnames = append(ssnames, fname)
}
}

// Sort snapshot names from newest to oldest
sort.Sort(sort.Reverse(sort.StringSlice(ssnames)))

return ssnames, nil
}

// Retrieves the log path for the server.
func (s *server) SnapshotPath(lastIndex uint64, lastTerm uint64) string {
return path.Join(s.path, "snapshot", fmt.Sprintf("%v_%v.ss", lastTerm, lastIndex))
Expand Down Expand Up @@ -1314,80 +1339,43 @@ func (s *server) processSnapshotRecoveryRequest(req *SnapshotRecoveryRequest) *S

// Load a snapshot at restart
func (s *server) LoadSnapshot() error {
// Open snapshot/ directory.
dir, err := os.OpenFile(path.Join(s.path, "snapshot"), os.O_RDONLY, 0)
if err != nil {
s.debugln("cannot.open.snapshot: ", err)
return err
}

// Retrieve a list of all snapshots.
filenames, err := dir.Readdirnames(-1)
if err != nil {
dir.Close()
panic(err)
}
dir.Close()

if len(filenames) == 0 {
s.debugln("no.snapshot.to.load")
return nil
}

// Grab the latest snapshot.
sort.Strings(filenames)
snapshotPath := path.Join(s.path, "snapshot", filenames[len(filenames)-1])

// Read snapshot data.
file, err := os.OpenFile(snapshotPath, os.O_RDONLY, 0)
sslist, err := s.SnapshotList()
if err != nil {
return err
}
defer file.Close()

// Check checksum.
var checksum uint32
n, err := fmt.Fscanf(file, "%08x\n", &checksum)
if err != nil {
return err
} else if n != 1 {
return errors.New("checksum.err: bad.snapshot.file")
// Load most recent snapshot (falling back to older snapshots if needed)
var ss *Snapshot
for _, ssname := range sslist {
ssFullPath := path.Join(s.path, "snapshot", ssname)
ss, err = loadSnapshot(ssFullPath)
if err == nil {
break
}
s.debugln(err)
}

// Load remaining snapshot contents.
b, err := ioutil.ReadAll(file)
if err != nil {
return err
return err // couldn't load any of the snapshots
}

// Generate checksum.
byteChecksum := crc32.ChecksumIEEE(b)
if uint32(checksum) != byteChecksum {
s.debugln(checksum, " ", byteChecksum)
return errors.New("bad snapshot file")
}

// Decode snapshot.
if err = json.Unmarshal(b, &s.snapshot); err != nil {
s.debugln("unmarshal.snapshot.error: ", err)
return err
}
s.snapshot = ss

// Recover snapshot into state machine.
if err = s.stateMachine.Recovery(s.snapshot.State); err != nil {
if err = s.stateMachine.Recovery(ss.State); err != nil {
s.debugln("recovery.snapshot.error: ", err)
return err
}

// Recover cluster configuration.
for _, peer := range s.snapshot.Peers {
for _, peer := range ss.Peers {
s.AddPeer(peer.Name, peer.ConnectionString)
}

// Update log state.
s.log.startTerm = s.snapshot.LastTerm
s.log.startIndex = s.snapshot.LastIndex
s.log.updateCommitIndex(s.snapshot.LastIndex)
s.log.startTerm = ss.LastTerm
s.log.startIndex = ss.LastIndex
s.log.updateCommitIndex(ss.LastIndex)

return err
}
Expand Down
51 changes: 51 additions & 0 deletions _vendor/raft/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package raft

import (
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"os"

"code.google.com/p/gogoprotobuf/proto"
"github.com/dgnorton/goback"
"github.com/influxdb/influxdb/_vendor/raft/protobuf"
)

Expand Down Expand Up @@ -51,13 +53,60 @@ type SnapshotResponse struct {
Success bool `json:"success"`
}

// loadSnapshot reads a snapshot from file
func loadSnapshot(filename string) (*Snapshot, error) {
file, err := os.OpenFile(filename, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
defer file.Close()

// Check checksum.
var checksum uint32
n, err := fmt.Fscanf(file, "%08x\n", &checksum)
if err != nil {
return nil, err
} else if n != 1 {
return nil, errors.New("checksum.err: bad.snapshot.file")
}

// Load remaining snapshot contents.
b, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}

// Generate checksum.
byteChecksum := crc32.ChecksumIEEE(b)
if uint32(checksum) != byteChecksum {
return nil, errors.New("bad snapshot file")
}

// Decode snapshot.
var ss Snapshot
if err = json.Unmarshal(b, &ss); err != nil {
return nil, err
}

return &ss, nil
}

// save writes the snapshot to file.
func (ss *Snapshot) save() error {
tx := goback.Begin()
defer tx.Rollback()

// Open the file for writing.
file, err := os.OpenFile(ss.Path, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
tx.Exec(func() error {
if err := os.Remove(ss.Path); err != nil {
panic(err)
}
return nil
})
defer file.Close()

// Serialize to JSON.
Expand All @@ -82,6 +131,8 @@ func (ss *Snapshot) save() error {
return err
}

tx.Commit()

return nil
}

Expand Down
32 changes: 32 additions & 0 deletions _vendor/raft/snapshot_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,45 @@
package raft

import (
"fmt"
"io/ioutil"
"os"
"path"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestSnapshotSave(t *testing.T) {
tstPath, err := ioutil.TempDir("", "snapshot_save")
if err != nil {
t.Errorf("Failed to create a temporary directory: %s", err)
}
defer os.RemoveAll(tstPath)

index := uint64(1)
term := uint64(2)
ssfile := path.Join(tstPath, fmt.Sprintf("%v_%v.ss", term, index))

ss := &Snapshot{
LastIndex: index,
LastTerm: term,
Path: ssfile,
}

err = ss.save()

if err != nil {
t.Error(err)
}

if _, err := os.Stat(ssfile); os.IsNotExist(err) {
t.Errorf("Failed to create snapshot: %s", ssfile)
}
}

// Ensure that a snapshot occurs when there are existing logs.
func TestSnapshot(t *testing.T) {
runServerWithMockStateMachine(Leader, func(s Server, m *mock.Mock) {
Expand Down
6 changes: 5 additions & 1 deletion coordinator/raft_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,11 @@ func (s *RaftServer) startRaft() error {
}

s.raftServer.SetElectionTimeout(s.config.RaftTimeout.Duration)
s.raftServer.LoadSnapshot() // ignore errors

err = s.raftServer.LoadSnapshot()
if err != nil {
panic(err)
}

s.raftServer.AddEventListener(raft.StateChangeEventType, s.raftEventHandler)

Expand Down

0 comments on commit 5e9750f

Please sign in to comment.